374 lines
12 KiB
Python
374 lines
12 KiB
Python
"""Usage Store — Persistent usage tracking with Redis Hash backend.
|
|
|
|
Provides UsageStore Protocol with InMemoryUsageStore and RedisUsageStore
|
|
backends. Replaces the in-memory list in UsageTracker with a pluggable
|
|
store that survives restarts and supports multi-instance deployment.
|
|
|
|
Key schema (Redis):
|
|
agentkit:usage:{date} → Hash: {agent_name:model:metric → running total}, metric ∈ cost | prompt_tokens | completion_tokens | total_tokens | count
|
|
agentkit:usage_records:{date} → List: JSON(UsageRecord) with LTRIM
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Any, Protocol, runtime_checkable
|
|
|
|
from agentkit.llm.protocol import TokenUsage
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class UsageRecord:
    """One recorded usage event for a single agent/model call.

    ``timestamp`` is stored as an ISO 8601 string so the record can be
    JSON-serialized as-is. When it is left empty, construction fills it in
    with the current UTC time.
    """

    agent_name: str
    model: str
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    cost: float
    latency_ms: float
    timestamp: str = ""  # ISO 8601 string for JSON serialization

    def __post_init__(self):
        # A caller-supplied timestamp wins; only stamp records that lack one.
        if self.timestamp:
            return
        self.timestamp = datetime.now(timezone.utc).isoformat()
|
|
|
|
|
|
@dataclass
class UsageBucket:
    """Running totals for one agent+model pair on a single date.

    Every counter starts at zero, so a fresh bucket represents "no usage".
    """

    # Token counters.
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    # Summed cost of the events folded into this bucket.
    cost: float = 0.0
    # Number of events folded into this bucket.
    count: int = 0
|
|
|
|
|
|
@dataclass
class UsageSummary:
    """Usage summary returned by a store query.

    An instance built with no arguments is the "nothing matched" result:
    zero totals, an empty per-model breakdown and no records.
    """

    # Totals across every matching record.
    total_tokens: int = 0
    total_cost: float = 0.0
    # Per-model breakdown, keyed by model name.
    by_model: dict[str, dict[str, int | float]] = field(default_factory=dict)
    # The individual records that were aggregated.
    records: list[UsageRecord] = field(default_factory=list)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# UsageStore Protocol
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@runtime_checkable
class UsageStore(Protocol):
    """Structural interface that every usage-store backend satisfies.

    The protocol is runtime-checkable, so ``isinstance(obj, UsageStore)``
    succeeds for any object exposing ``record`` and ``get_usage``.
    """

    def record(
        self,
        agent_name: str,
        model: str,
        usage: TokenUsage,
        cost: float,
        latency_ms: float,
    ) -> None:
        """Store one usage event for ``agent_name`` calling ``model``."""
        ...

    def get_usage(
        self,
        agent_name: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ) -> UsageSummary:
        """Return a summary of recorded usage.

        All three arguments are optional filters; ``None`` means "do not
        filter on this dimension".
        """
        ...
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# InMemoryUsageStore
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class InMemoryUsageStore:
    """In-memory usage store (drop-in replacement for old UsageTracker).

    Records live in a plain list capped at ``MAX_RECORDS`` entries (oldest
    dropped first). Nothing is persisted; the data is lost when the process
    exits.
    """

    MAX_RECORDS = 10000

    def __init__(self):
        self._records: list[UsageRecord] = []

    @staticmethod
    def _as_utc(moment: datetime) -> datetime:
        """Return ``moment`` as an aware datetime, reading naive values as UTC.

        Record timestamps are timezone-aware, and Python refuses to compare
        aware with naive datetimes, so bounds are normalized before use.
        """
        if moment.utcoffset() is None:
            return moment.replace(tzinfo=timezone.utc)
        return moment

    def record(
        self,
        agent_name: str,
        model: str,
        usage: TokenUsage,
        cost: float,
        latency_ms: float,
    ) -> None:
        """Append one usage event and drop the oldest entries over the cap.

        Args:
            agent_name: Name of the agent that made the call.
            model: Model identifier the call was made against.
            usage: Token counts; ``prompt_tokens``, ``completion_tokens`` and
                ``total_tokens`` are read from it.
            cost: Cost attributed to the call.
            latency_ms: Latency of the call in milliseconds.
        """
        self._records.append(
            UsageRecord(
                agent_name=agent_name,
                model=model,
                prompt_tokens=usage.prompt_tokens,
                completion_tokens=usage.completion_tokens,
                total_tokens=usage.total_tokens,
                cost=cost,
                latency_ms=latency_ms,
            )
        )
        # Trim in place from the front; computing the overflow explicitly
        # (rather than slicing with a negative index) stays correct even if
        # a subclass sets MAX_RECORDS to 0.
        overflow = len(self._records) - self.MAX_RECORDS
        if overflow > 0:
            del self._records[:overflow]

    def get_usage(
        self,
        agent_name: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ) -> UsageSummary:
        """Summarize stored records, optionally filtered.

        Args:
            agent_name: Only include records for this agent, if given.
            start_time: Inclusive lower bound on the record timestamp, if
                given. Naive datetimes are interpreted as UTC.
            end_time: Inclusive upper bound on the record timestamp, if
                given. Naive datetimes are interpreted as UTC.

        Returns:
            A ``UsageSummary`` with totals, a per-model breakdown and the
            matching records. ``records`` is always a fresh list, so callers
            may mutate it without affecting the store. An empty summary is
            returned when nothing matches.
        """
        lower = self._as_utc(start_time) if start_time is not None else None
        upper = self._as_utc(end_time) if end_time is not None else None
        time_filtered = lower is not None or upper is not None

        selected: list[UsageRecord] = []
        for rec in self._records:
            if agent_name is not None and rec.agent_name != agent_name:
                continue
            if time_filtered:
                # Parse once per record even when both bounds are supplied.
                stamp = self._as_utc(datetime.fromisoformat(rec.timestamp))
                if lower is not None and stamp < lower:
                    continue
                if upper is not None and stamp > upper:
                    continue
            selected.append(rec)

        if not selected:
            return UsageSummary()

        by_model: dict[str, dict[str, int | float]] = {}
        for rec in selected:
            stats = by_model.setdefault(
                rec.model, {"total_tokens": 0, "total_cost": 0.0, "count": 0}
            )
            stats["total_tokens"] += rec.total_tokens
            stats["total_cost"] += rec.cost
            stats["count"] += 1

        return UsageSummary(
            total_tokens=sum(rec.total_tokens for rec in selected),
            total_cost=sum(rec.cost for rec in selected),
            by_model=by_model,
            records=selected,
        )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# RedisUsageStore
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class RedisUsageStore:
    """Redis-backed usage store with an in-memory fallback.

    Key schema (``{date}`` is the UTC date, ``YYYY-MM-DD``):
        agentkit:usage:{date}         → Hash of running totals. Fields are
                                        ``{agent}:{model}:{metric}`` where
                                        metric is one of cost, prompt_tokens,
                                        completion_tokens, total_tokens, count.
        agentkit:usage_records:{date} → List of JSON(UsageRecord), capped at
                                        ``MAX_RECORDS_PER_DAY`` with LTRIM.

    Both keys have their TTL refreshed to ``TTL_DAYS`` on every write.

    Once a write fails, this instance switches to an ``InMemoryUsageStore``
    and keeps using it for all later reads and writes; nothing in this class
    switches it back.
    """

    USAGE_PREFIX = "agentkit:usage:"
    RECORDS_PREFIX = "agentkit:usage_records:"
    MAX_RECORDS_PER_DAY = 50000
    TTL_DAYS = 90  # Auto-expire after 90 days

    def __init__(self, redis_url: str = "redis://localhost:6379"):
        """Remember the connection URL; clients are created lazily on first use."""
        self._redis_url = redis_url
        self._redis: Any = None
        self._sync_redis: Any = None
        self._fallback: InMemoryUsageStore | None = None
        self._degraded = False

    async def _get_redis(self):
        """Get or create the async Redis client."""
        if self._redis is None:
            import redis.asyncio as aioredis

            self._redis = aioredis.from_url(self._redis_url, decode_responses=True)
        return self._redis

    def _get_sync_redis(self):
        """Get or create a persistent sync Redis client (connection pool backed)."""
        if self._sync_redis is None:
            import redis as sync_redis

            self._sync_redis = sync_redis.from_url(
                self._redis_url, decode_responses=True
            )
        return self._sync_redis

    async def aclose(self) -> None:
        """Close whichever Redis clients were created and forget them."""
        if self._redis is not None:
            await self._redis.aclose()
            self._redis = None
        if self._sync_redis is not None:
            self._sync_redis.close()
            self._sync_redis = None

    def _degrade_to_fallback(self) -> None:
        """Switch to the in-memory fallback store (idempotent, logs once)."""
        if not self._degraded:
            self._degraded = True
            if self._fallback is None:
                self._fallback = InMemoryUsageStore()
            logger.warning("Redis usage store unreachable, degraded to in-memory")

    def _today_key(self) -> str:
        """Return today's UTC date as ``YYYY-MM-DD`` (the per-day key suffix)."""
        return datetime.now(timezone.utc).strftime("%Y-%m-%d")

    @staticmethod
    def _as_utc(moment: datetime) -> datetime:
        """Return ``moment`` as an aware datetime, reading naive values as UTC.

        Stored timestamps are timezone-aware, and Python refuses to compare
        aware with naive datetimes, so bounds are normalized before use.
        """
        if moment.utcoffset() is None:
            return moment.replace(tzinfo=timezone.utc)
        return moment

    @classmethod
    def _parse_record(cls, raw: str) -> tuple[UsageRecord, datetime] | None:
        """Decode one stored list entry into ``(record, aware timestamp)``.

        Returns ``None`` (after logging) for entries that are not valid JSON,
        do not match the ``UsageRecord`` fields, or carry an unparsable
        timestamp, so that one bad entry does not abort a whole query.
        """
        try:
            rec = UsageRecord(**json.loads(raw))
            stamp = cls._as_utc(datetime.fromisoformat(rec.timestamp))
        except (TypeError, ValueError) as e:  # JSONDecodeError is a ValueError
            logger.warning(f"Skipping malformed usage record: {e}")
            return None
        return rec, stamp

    def record(
        self,
        agent_name: str,
        model: str,
        usage: TokenUsage,
        cost: float,
        latency_ms: float,
    ) -> None:
        """Record one usage event using the sync Redis client.

        This is a sync method because UsageTracker.record() is sync; a sync
        Redis client is used for writes so the caller does not need an event
        loop.

        On any failure the store degrades to the in-memory fallback and the
        event is recorded there instead, so the call never raises because of
        Redis.

        Args:
            agent_name: Name of the agent that made the call.
            model: Model identifier the call was made against.
            usage: Token counts; ``prompt_tokens``, ``completion_tokens`` and
                ``total_tokens`` are read from it.
            cost: Cost attributed to the call.
            latency_ms: Latency of the call in milliseconds.
        """
        if self._degraded and self._fallback is not None:
            self._fallback.record(agent_name, model, usage, cost, latency_ms)
            return

        try:
            client = self._get_sync_redis()

            date_key = self._today_key()
            hash_key = f"{self.USAGE_PREFIX}{date_key}"
            list_key = f"{self.RECORDS_PREFIX}{date_key}"
            bucket_field = f"{agent_name}:{model}"

            # All commands go out in one pipeline so the write costs a single
            # round-trip.
            pipe = client.pipeline()

            # Atomic server-side increments for the per-agent/model totals.
            pipe.hincrbyfloat(hash_key, f"{bucket_field}:cost", cost)
            pipe.hincrby(hash_key, f"{bucket_field}:prompt_tokens", usage.prompt_tokens)
            pipe.hincrby(hash_key, f"{bucket_field}:completion_tokens", usage.completion_tokens)
            pipe.hincrby(hash_key, f"{bucket_field}:total_tokens", usage.total_tokens)
            pipe.hincrby(hash_key, f"{bucket_field}:count", 1)

            # Append the raw record and keep only the newest entries.
            rec = UsageRecord(
                agent_name=agent_name,
                model=model,
                prompt_tokens=usage.prompt_tokens,
                completion_tokens=usage.completion_tokens,
                total_tokens=usage.total_tokens,
                cost=cost,
                latency_ms=latency_ms,
            )
            pipe.rpush(list_key, json.dumps({
                "agent_name": rec.agent_name,
                "model": rec.model,
                "prompt_tokens": rec.prompt_tokens,
                "completion_tokens": rec.completion_tokens,
                "total_tokens": rec.total_tokens,
                "cost": rec.cost,
                "latency_ms": rec.latency_ms,
                "timestamp": rec.timestamp,
            }))
            pipe.ltrim(list_key, -self.MAX_RECORDS_PER_DAY, -1)

            # Refresh the TTL on every write (not just the first of the day).
            pipe.expire(hash_key, self.TTL_DAYS * 86400)
            pipe.expire(list_key, self.TTL_DAYS * 86400)

            pipe.execute()
        except Exception as e:
            # Best-effort by design: usage tracking must not break the caller.
            logger.warning(f"Redis usage record failed: {e}")
            self._degrade_to_fallback()
            if self._fallback is not None:
                self._fallback.record(agent_name, model, usage, cost, latency_ms)

    def get_usage(
        self,
        agent_name: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ) -> UsageSummary:
        """Query a usage summary from the per-day record lists in Redis.

        Args:
            agent_name: Only include records for this agent, if given.
            start_time: Inclusive lower bound on the record timestamp.
                Defaults to 2020-01-01 UTC. Naive datetimes are read as UTC.
            end_time: Inclusive upper bound on the record timestamp. Defaults
                to the current time. Naive datetimes are read as UTC.

        Returns:
            A ``UsageSummary`` with totals, a per-model breakdown and the
            matching records. When the store is degraded, or the Redis query
            fails and a fallback store exists, the fallback's summary is
            returned; otherwise a failed query yields an empty summary.
        """
        if self._degraded and self._fallback is not None:
            return self._fallback.get_usage(agent_name, start_time, end_time)

        try:
            client = self._get_sync_redis()

            now = datetime.now(timezone.utc)
            start = (
                self._as_utc(start_time)
                if start_time is not None
                else datetime(2020, 1, 1, tzinfo=timezone.utc)
            )
            end = self._as_utc(end_time) if end_time is not None else now

            # Day keys are named by UTC date, so convert the bounds to UTC
            # before taking the date. Then clamp the scan: day keys older
            # than TTL_DAYS (+1 day of slack) have already expired, and keys
            # beyond tomorrow cannot exist yet, so there is no point issuing
            # an LRANGE for each of those days.
            today = now.date()
            first_day = max(
                start.astimezone(timezone.utc).date(),
                today - timedelta(days=self.TTL_DAYS + 1),
            )
            last_day = min(
                end.astimezone(timezone.utc).date(),
                today + timedelta(days=1),
            )

            matched: list[UsageRecord] = []
            day = first_day
            while day <= last_day:
                list_key = f"{self.RECORDS_PREFIX}{day.isoformat()}"
                for raw in client.lrange(list_key, 0, -1):
                    parsed = self._parse_record(raw)
                    if parsed is None:
                        continue
                    rec, stamp = parsed
                    if not (start <= stamp <= end):
                        continue
                    if agent_name is None or rec.agent_name == agent_name:
                        matched.append(rec)
                day += timedelta(days=1)

            if not matched:
                return UsageSummary()

            by_model: dict[str, dict[str, int | float]] = {}
            for rec in matched:
                stats = by_model.setdefault(
                    rec.model, {"total_tokens": 0, "total_cost": 0.0, "count": 0}
                )
                stats["total_tokens"] += rec.total_tokens
                stats["total_cost"] += rec.cost
                stats["count"] += 1

            return UsageSummary(
                total_tokens=sum(rec.total_tokens for rec in matched),
                total_cost=sum(rec.cost for rec in matched),
                by_model=by_model,
                records=matched,
            )
        except Exception as e:
            # Best-effort by design: a failed read degrades to whatever the
            # fallback holds (or nothing) instead of raising.
            logger.warning(f"Redis usage query failed: {e}")
            if self._fallback is not None:
                return self._fallback.get_usage(agent_name, start_time, end_time)
            return UsageSummary()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Factory
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def create_usage_store(
    backend: str = "auto",
    redis_url: str = "redis://localhost:6379",
) -> UsageStore:
    """Build the usage store for the requested backend.

    Only the importability of the ``redis`` package is checked here; no
    connection to the server is attempted.

    Args:
        backend: "auto" or "redis" select the Redis store when the ``redis``
            package can be imported and otherwise fall back to memory (with a
            warning). Any other value, including "memory", selects the
            in-memory store.
        redis_url: Redis connection URL.

    Returns:
        A UsageStore instance.
    """
    # Anything that is not a Redis-capable choice goes straight to memory.
    if backend not in ("auto", "redis"):
        return InMemoryUsageStore()

    try:
        import redis  # noqa: F401

        store: UsageStore = RedisUsageStore(redis_url=redis_url)
    except ImportError:
        logger.warning("redis package not available, falling back to in-memory usage store")
        store = InMemoryUsageStore()
    return store
|