fischer-agentkit/src/agentkit/llm/providers/usage_store.py

374 lines
12 KiB
Python

"""Usage Store — Persistent usage tracking with Redis Hash backend.
Provides UsageStore Protocol with InMemoryUsageStore and RedisUsageStore
backends. Replaces the in-memory list in UsageTracker with a pluggable
store that survives restarts and supports multi-instance deployment.
Key schema (Redis):
agentkit:usage:{date} → Hash: {agent_name:model → JSON(UsageBucket)}
agentkit:usage_records:{date} → List: JSON(UsageRecord) with LTRIM
"""
import json
import logging
from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Any, Protocol, runtime_checkable

from agentkit.llm.protocol import TokenUsage
# Module logger; the stores below report Redis failures and backend fallbacks through it.
logger = logging.getLogger(__name__)
@dataclass
class UsageRecord:
    """One recorded usage event for an agent/model pair.

    ``timestamp`` is kept as an ISO 8601 string so the record can be
    JSON-serialized as-is. When it is left empty (the default), it is
    filled in with the current UTC time at construction.
    """

    agent_name: str
    model: str
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    cost: float
    latency_ms: float
    timestamp: str = ""  # ISO 8601 string for JSON serialization

    def __post_init__(self):
        # A falsy timestamp means "stamp it now" (always in UTC).
        self.timestamp = self.timestamp or datetime.now(timezone.utc).isoformat()
@dataclass
class UsageBucket:
    """Running totals for one agent+model pair on one date.

    All counters start at zero. Note that RedisUsageStore in this module
    stores these same metrics as separate hash fields rather than as a
    serialized UsageBucket.
    """

    prompt_tokens: int = field(default=0)
    completion_tokens: int = field(default=0)
    total_tokens: int = field(default=0)
    cost: float = field(default=0.0)
    count: int = field(default=0)  # number of usage events folded into the totals
@dataclass
class UsageSummary:
    """Aggregate view over a set of usage records.

    ``by_model`` maps a model name to its totals; the stores in this module
    fill each entry with "total_tokens", "total_cost" and "count" keys.
    ``records`` holds the individual records that were aggregated. An
    instance built with no arguments represents "no usage".
    """

    total_tokens: int = 0
    total_cost: float = 0.0
    # Mutable defaults go through default_factory so instances never share them.
    by_model: dict[str, dict[str, int | float]] = field(default_factory=dict)
    records: list[UsageRecord] = field(default_factory=list)
# ---------------------------------------------------------------------------
# UsageStore Protocol
# ---------------------------------------------------------------------------
@runtime_checkable
class UsageStore(Protocol):
    """Structural interface that every usage-store backend satisfies.

    Being ``runtime_checkable``, ``isinstance(obj, UsageStore)`` only checks
    that ``record`` and ``get_usage`` exist on *obj*; it does not validate
    their signatures.
    """

    def record(
        self,
        agent_name: str,
        model: str,
        usage: TokenUsage,
        cost: float,
        latency_ms: float,
    ) -> None:
        """Persist one usage event.

        Args:
            agent_name: Agent the usage is attributed to.
            model: Model identifier.
            usage: Token counts for the event.
            cost: Cost attributed to the event.
            latency_ms: Latency of the event in milliseconds.
        """
        ...

    def get_usage(
        self,
        agent_name: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ) -> UsageSummary:
        """Return a UsageSummary of recorded usage.

        Args:
            agent_name: Restrict to this agent; ``None`` means all agents.
            start_time: Lower bound of the time window, or ``None`` for no bound.
            end_time: Upper bound of the time window, or ``None`` for no bound.
        """
        ...
# ---------------------------------------------------------------------------
# InMemoryUsageStore
# ---------------------------------------------------------------------------
class InMemoryUsageStore:
    """In-memory usage store (drop-in replacement for the old UsageTracker).

    Keeps at most ``MAX_RECORDS`` of the most recent records in process
    memory; the oldest are discarded first and nothing survives a restart.
    """

    MAX_RECORDS = 10000

    def __init__(self):
        self._records: list[UsageRecord] = []

    @staticmethod
    def _as_utc(dt: datetime) -> datetime:
        """Return *dt* as an aware UTC datetime.

        Naive datetimes are treated as UTC. NOTE(review): this assumes callers
        pass UTC-based naive values; confirm against callers if local time is
        expected instead.
        """
        if dt.utcoffset() is None:
            return dt.replace(tzinfo=timezone.utc)
        return dt.astimezone(timezone.utc)

    @staticmethod
    def _summarize(records: list[UsageRecord]) -> UsageSummary:
        """Aggregate *records* into totals overall and per model."""
        if not records:
            return UsageSummary()
        by_model: dict[str, dict[str, int | float]] = {}
        for rec in records:
            stats = by_model.setdefault(
                rec.model, {"total_tokens": 0, "total_cost": 0.0, "count": 0}
            )
            stats["total_tokens"] += rec.total_tokens
            stats["total_cost"] += rec.cost
            stats["count"] += 1
        return UsageSummary(
            total_tokens=sum(rec.total_tokens for rec in records),
            total_cost=sum(rec.cost for rec in records),
            by_model=by_model,
            records=records,
        )

    def record(
        self,
        agent_name: str,
        model: str,
        usage: TokenUsage,
        cost: float,
        latency_ms: float,
    ) -> None:
        """Append one usage event, evicting the oldest beyond ``MAX_RECORDS``."""
        self._records.append(
            UsageRecord(
                agent_name=agent_name,
                model=model,
                prompt_tokens=usage.prompt_tokens,
                completion_tokens=usage.completion_tokens,
                total_tokens=usage.total_tokens,
                cost=cost,
                latency_ms=latency_ms,
            )
        )
        # Trim in place instead of re-slicing into a new list. Unlike
        # `[-MAX_RECORDS:]`, this also behaves correctly when MAX_RECORDS is 0.
        overflow = len(self._records) - self.MAX_RECORDS
        if overflow > 0:
            del self._records[:overflow]

    def get_usage(
        self,
        agent_name: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ) -> UsageSummary:
        """Summarize the retained records, optionally filtered.

        Args:
            agent_name: Only include records for this agent (all if None).
            start_time: Inclusive lower bound; naive values are treated as UTC.
            end_time: Inclusive upper bound; naive values are treated as UTC.

        Returns:
            A UsageSummary whose ``records`` list is a fresh copy, never the
            store's internal buffer.
        """
        start = self._as_utc(start_time) if start_time is not None else None
        end = self._as_utc(end_time) if end_time is not None else None
        check_time = start is not None or end is not None

        selected: list[UsageRecord] = []
        for rec in self._records:
            if agent_name is not None and rec.agent_name != agent_name:
                continue
            if check_time:
                # Parse each timestamp once, not once per bound.
                ts = self._as_utc(datetime.fromisoformat(rec.timestamp))
                if (start is not None and ts < start) or (end is not None and ts > end):
                    continue
            selected.append(rec)
        # `selected` is always a new list, so later record() calls cannot
        # mutate a summary that has already been returned.
        return self._summarize(selected)
# ---------------------------------------------------------------------------
# RedisUsageStore
# ---------------------------------------------------------------------------
class RedisUsageStore:
    """Redis-backed usage store, partitioned into one pair of keys per UTC date.

    Key schema:
        agentkit:usage:{YYYY-MM-DD}
            Hash of running totals. Each agent+model pair gets one field per
            metric: "{agent}:{model}:cost" (HINCRBYFLOAT) and
            "{agent}:{model}:prompt_tokens" / ":completion_tokens" /
            ":total_tokens" / ":count" (HINCRBY).
        agentkit:usage_records:{YYYY-MM-DD}
            List of JSON(UsageRecord), LTRIM-ed to the newest
            MAX_RECORDS_PER_DAY entries.

    Both keys get their TTL (TTL_DAYS) refreshed on every write.

    ``get_usage`` is computed from the per-record lists only, so on a day with
    more than MAX_RECORDS_PER_DAY events it sees just the newest ones. This
    class writes the hash totals but does not read them.

    After the first failed write, the instance switches to an in-memory
    fallback and does not retry Redis afterwards.
    """

    USAGE_PREFIX = "agentkit:usage:"
    RECORDS_PREFIX = "agentkit:usage_records:"
    MAX_RECORDS_PER_DAY = 50000
    TTL_DAYS = 90  # Auto-expire after 90 days

    def __init__(self, redis_url: str = "redis://localhost:6379"):
        self._redis_url = redis_url
        self._redis: Any = None  # async client, created lazily by _get_redis()
        self._sync_redis: Any = None  # sync client, created lazily by _get_sync_redis()
        self._fallback: InMemoryUsageStore | None = None
        self._degraded = False

    async def _get_redis(self):
        """Lazily create the asyncio Redis client.

        Reads and writes in this class use the sync client; within this class,
        the async client is only closed by aclose().
        """
        if self._redis is None:
            import redis.asyncio as aioredis

            self._redis = aioredis.from_url(self._redis_url, decode_responses=True)
        return self._redis

    def _get_sync_redis(self):
        """Get or create a persistent sync Redis client (connection pool backed)."""
        if self._sync_redis is None:
            import redis as sync_redis

            self._sync_redis = sync_redis.from_url(
                self._redis_url, decode_responses=True
            )
        return self._sync_redis

    async def aclose(self) -> None:
        """Close whichever Redis clients have been created."""
        if self._redis is not None:
            await self._redis.aclose()
            self._redis = None
        if self._sync_redis is not None:
            self._sync_redis.close()
            self._sync_redis = None

    def _degrade_to_fallback(self) -> None:
        """Switch to the in-memory fallback, logging only on the first switch."""
        if not self._degraded:
            self._degraded = True
            if self._fallback is None:
                self._fallback = InMemoryUsageStore()
            logger.warning("Redis usage store unreachable, degraded to in-memory")

    def _today_key(self) -> str:
        """Current UTC date as YYYY-MM-DD (the per-day key suffix)."""
        return datetime.now(timezone.utc).strftime("%Y-%m-%d")

    @staticmethod
    def _as_utc(dt: datetime) -> datetime:
        """Return *dt* as an aware UTC datetime.

        Naive datetimes are treated as UTC. NOTE(review): this assumes callers
        pass UTC-based naive values; confirm against callers if local time is
        expected instead.
        """
        if dt.utcoffset() is None:
            return dt.replace(tzinfo=timezone.utc)
        return dt.astimezone(timezone.utc)

    @classmethod
    def _decode_record(cls, raw: str) -> tuple[UsageRecord, datetime] | None:
        """Parse one stored JSON record and its UTC timestamp; None if malformed."""
        try:
            rec = UsageRecord(**json.loads(raw))
            ts = cls._as_utc(datetime.fromisoformat(rec.timestamp))
        except (ValueError, TypeError) as e:
            # One corrupt entry should not blank out the whole query.
            logger.debug("Skipping malformed usage record %r: %s", raw, e)
            return None
        return rec, ts

    @staticmethod
    def _summarize(records: list[UsageRecord]) -> UsageSummary:
        """Aggregate *records* into totals overall and per model."""
        if not records:
            return UsageSummary()
        by_model: dict[str, dict[str, int | float]] = {}
        for rec in records:
            stats = by_model.setdefault(
                rec.model, {"total_tokens": 0, "total_cost": 0.0, "count": 0}
            )
            stats["total_tokens"] += rec.total_tokens
            stats["total_cost"] += rec.cost
            stats["count"] += 1
        return UsageSummary(
            total_tokens=sum(rec.total_tokens for rec in records),
            total_cost=sum(rec.cost for rec in records),
            by_model=by_model,
            records=records,
        )

    def record(
        self,
        agent_name: str,
        model: str,
        usage: TokenUsage,
        cost: float,
        latency_ms: float,
    ) -> None:
        """Record one usage event.

        This is deliberately synchronous because UsageTracker.record() is
        sync. Writes therefore use the sync Redis client and need no event
        loop in the caller. On any Redis error the failure is logged, the
        store degrades to in-memory mode, and the event goes to the fallback.
        """
        if self._degraded and self._fallback is not None:
            self._fallback.record(agent_name, model, usage, cost, latency_ms)
            return
        try:
            client = self._get_sync_redis()
            date_key = self._today_key()
            hash_key = f"{self.USAGE_PREFIX}{date_key}"
            list_key = f"{self.RECORDS_PREFIX}{date_key}"
            bucket_field = f"{agent_name}:{model}"
            ttl_seconds = self.TTL_DAYS * 86400
            rec = UsageRecord(
                agent_name=agent_name,
                model=model,
                prompt_tokens=usage.prompt_tokens,
                completion_tokens=usage.completion_tokens,
                total_tokens=usage.total_tokens,
                cost=cost,
                latency_ms=latency_ms,
            )
            # pipeline() is transactional by default in redis-py, so all of
            # these commands are sent in one round trip and applied together.
            pipe = client.pipeline()
            pipe.hincrbyfloat(hash_key, f"{bucket_field}:cost", cost)
            pipe.hincrby(hash_key, f"{bucket_field}:prompt_tokens", usage.prompt_tokens)
            pipe.hincrby(hash_key, f"{bucket_field}:completion_tokens", usage.completion_tokens)
            pipe.hincrby(hash_key, f"{bucket_field}:total_tokens", usage.total_tokens)
            pipe.hincrby(hash_key, f"{bucket_field}:count", 1)
            # asdict keeps the JSON in sync with UsageRecord's fields.
            pipe.rpush(list_key, json.dumps(asdict(rec)))
            pipe.ltrim(list_key, -self.MAX_RECORDS_PER_DAY, -1)
            # The TTL is refreshed on every write, not just the first of the
            # day, so a key expires TTL_DAYS after that day's last write.
            pipe.expire(hash_key, ttl_seconds)
            pipe.expire(list_key, ttl_seconds)
            pipe.execute()
        except Exception as e:
            logger.warning("Redis usage record failed: %s", e)
            self._degrade_to_fallback()
            if self._fallback is not None:
                self._fallback.record(agent_name, model, usage, cost, latency_ms)

    def get_usage(
        self,
        agent_name: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ) -> UsageSummary:
        """Summarize recorded usage, optionally filtered by agent and time window.

        Args:
            agent_name: Only include records for this agent (all agents if None).
            start_time: Inclusive lower bound; naive values are treated as UTC.
            end_time: Inclusive upper bound (defaults to now); naive values are
                treated as UTC.

        Returns:
            A UsageSummary. On a Redis error, the in-memory fallback is queried
            if one exists; otherwise an empty summary is returned.
        """
        if self._degraded and self._fallback is not None:
            return self._fallback.get_usage(agent_name, start_time, end_time)
        try:
            client = self._get_sync_redis()
            now = datetime.now(timezone.utc)
            # Normalize to UTC before taking .date(): keys are named by UTC date,
            # and aware datetimes in another zone would otherwise map to the
            # wrong day.
            start = self._as_utc(start_time) if start_time is not None else None
            end = self._as_utc(end_time) if end_time is not None else now

            # A day's keys expire TTL_DAYS after that day's last write, so no
            # key older than TTL_DAYS + 1 days can still exist. Skip probing them.
            oldest_day = (now - timedelta(days=self.TTL_DAYS + 1)).date()
            day = oldest_day if start is None else max(start.date(), oldest_day)
            last_day = end.date()

            pipe = client.pipeline(transaction=False)
            queued = 0
            while day <= last_day:
                pipe.lrange(f"{self.RECORDS_PREFIX}{day.isoformat()}", 0, -1)
                queued += 1
                day += timedelta(days=1)
            if queued == 0:
                return UsageSummary()
            # One round trip for all dates instead of one LRANGE call per day.
            batches = pipe.execute()

            matched: list[UsageRecord] = []
            for batch in batches:
                for raw in batch:
                    decoded = self._decode_record(raw)
                    if decoded is None:
                        continue
                    rec, ts = decoded
                    if agent_name is not None and rec.agent_name != agent_name:
                        continue
                    if (start is not None and ts < start) or ts > end:
                        continue
                    matched.append(rec)
            return self._summarize(matched)
        except Exception as e:
            logger.warning("Redis usage query failed: %s", e)
            if self._fallback is not None:
                return self._fallback.get_usage(agent_name, start_time, end_time)
            return UsageSummary()
# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------
def create_usage_store(
    backend: str = "auto",
    redis_url: str = "redis://localhost:6379",
) -> UsageStore:
    """Create a usage store backend.

    Args:
        backend: "auto" or "redis" select RedisUsageStore when the ``redis``
            package is importable; otherwise an InMemoryUsageStore is
            returned with a warning. "memory" selects InMemoryUsageStore.
            Any other value also yields InMemoryUsageStore, with a warning,
            so that misconfiguration is visible in the logs.
        redis_url: Redis connection URL (used only by the Redis backend).

    Returns:
        A UsageStore instance.

    No connection is attempted here. RedisUsageStore creates its client
    lazily and switches to in-memory storage after its first failed write.
    """
    if backend not in ("auto", "redis", "memory"):
        logger.warning(
            "Unknown usage store backend %r, using in-memory usage store", backend
        )
        return InMemoryUsageStore()
    if backend == "memory":
        return InMemoryUsageStore()
    try:
        import redis  # noqa: F401
    except ImportError:
        logger.warning("redis package not available, falling back to in-memory usage store")
        return InMemoryUsageStore()
    return RedisUsageStore(redis_url=redis_url)