106 lines
3.2 KiB
Python
106 lines
3.2 KiB
Python
"""LLM 调用全局限速器(令牌桶算法)
|
||
|
||
所有 LLMProvider 实例共享同一个 RateLimiter 单例,
|
||
确保跨 Provider 的总调用频率不超过配置上限。
|
||
"""
|
||
|
||
import asyncio
|
||
import logging
|
||
import os
|
||
import time
|
||
from collections import deque
|
||
|
||
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class TokenBucketRateLimiter:
    """Token-bucket rate limiter for LLM calls.

    The bucket holds up to ``max_rpm`` tokens (but never fewer than one),
    starts full, and is refilled continuously at ``max_rpm / 60`` tokens per
    second.  Each :meth:`acquire` call consumes one token and waits while the
    bucket is empty.

    The default limit is 30 RPM (about 0.5 tokens per second).  The singleton
    returned by :meth:`get_instance` reads its limit from the
    ``LLM_RATE_LIMIT_RPM`` environment variable.
    """

    # Limit used when LLM_RATE_LIMIT_RPM is unset or unusable.
    _DEFAULT_RPM = 30.0

    _instance: "TokenBucketRateLimiter | None" = None
    _lock = asyncio.Lock()

    def __init__(
        self,
        max_rpm: float = 30.0,
    ):
        """Create a limiter that allows ``max_rpm`` acquisitions per minute.

        Args:
            max_rpm: Sustained requests-per-minute limit; must be a positive,
                finite number.

        Raises:
            ValueError: If ``max_rpm`` is zero, negative, NaN or infinite.
        """
        # NaN fails both comparisons, so it is rejected here as well.
        if not 0 < max_rpm < float("inf"):
            raise ValueError(
                f"max_rpm must be a positive, finite number, got {max_rpm!r}"
            )
        self._max_rpm = max_rpm
        self._refill_rate = max_rpm / 60.0  # tokens per second
        # The bucket must be able to hold one whole token; otherwise a limit
        # below 1 RPM could never satisfy an acquire().
        self._max_tokens = max(1.0, max_rpm)
        self._tokens = self._max_tokens  # start full
        self._last_refill = time.monotonic()
        # Serializes acquire(): only one coroutine at a time may inspect the
        # bucket or wait for it to refill.
        self._semaphore = asyncio.Semaphore(1)
        # Timestamps of recent grants (debugging / monitoring only).
        self._recent_requests: deque[float] = deque(
            maxlen=max(1, int(max_rpm * 2))
        )

    @classmethod
    def _rpm_from_env(cls) -> float:
        """Read LLM_RATE_LIMIT_RPM, falling back to the default if unusable."""
        raw = os.getenv("LLM_RATE_LIMIT_RPM", "30")
        try:
            rpm = float(raw)
        except ValueError:
            rpm = float("nan")
        if 0 < rpm < float("inf"):
            return rpm
        logger.warning(
            "Invalid LLM_RATE_LIMIT_RPM=%r; falling back to %s RPM",
            raw,
            cls._DEFAULT_RPM,
        )
        return cls._DEFAULT_RPM

    @classmethod
    def get_instance(cls) -> "TokenBucketRateLimiter":
        """Return the global singleton, creating it on first use.

        The limit comes from the ``LLM_RATE_LIMIT_RPM`` environment variable
        (default 30).  A malformed or non-positive value is logged and
        replaced by the default instead of raising.
        """
        if cls._instance is None:
            rpm = cls._rpm_from_env()
            cls._instance = cls(max_rpm=rpm)
            logger.info("LLM RateLimiter initialized: %s RPM", rpm)
        return cls._instance

    @classmethod
    async def reset_instance(cls) -> None:
        """Drop the singleton so the next get_instance() builds a new one.

        Intended for tests only.
        """
        async with cls._lock:
            cls._instance = None

    def _refill(self) -> None:
        """Add the tokens accrued since the last refill, capped at capacity."""
        now = time.monotonic()
        elapsed = now - self._last_refill
        tokens_to_add = elapsed * self._refill_rate
        self._tokens = min(self._max_tokens, self._tokens + tokens_to_add)
        self._last_refill = now

    async def acquire(self) -> None:
        """Consume one token, waiting until one is available.

        Safe to call concurrently from many coroutines.  The semaphore is
        held for the whole wait, so waiters are admitted one at a time and a
        token is only deducted once the bucket really contains one.
        Releasing the semaphore while sleeping would let several waiters
        wake together and all proceed, exceeding the configured rate.
        """
        async with self._semaphore:
            while True:
                self._refill()
                if self._tokens >= 1.0:
                    self._tokens -= 1.0
                    self._recent_requests.append(time.monotonic())
                    return

                # Sleep just long enough for the missing fraction to accrue,
                # then re-check: timer granularity can wake us a hair early.
                wait_time = (1.0 - self._tokens) / self._refill_rate
                logger.debug(
                    "LLM rate limiter: waiting %.2fs for token", wait_time
                )
                await asyncio.sleep(wait_time)

    @property
    def available_tokens(self) -> float:
        """Token count as of the most recent refill (not recomputed on read)."""
        return self._tokens

    @property
    def max_rpm(self) -> float:
        """The configured requests-per-minute limit."""
        return self._max_rpm
|
||
|
||
|
||
# 模块级便捷函数
|
||
# Module-level slot for a limiter instance. Nothing in the visible code reads
# or writes it; get_rate_limiter() delegates to
# TokenBucketRateLimiter.get_instance() instead.
# NOTE(review): presumably left over from an earlier design — confirm no
# other module imports it before removing.
_rate_limiter: TokenBucketRateLimiter | None = None
|
||
|
||
|
||
def get_rate_limiter() -> TokenBucketRateLimiter:
    """Return the singleton from TokenBucketRateLimiter.get_instance()."""
    shared_limiter = TokenBucketRateLimiter.get_instance()
    return shared_limiter
|