geo/backend/app/services/llm/rate_limiter.py

110 lines
3.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""LLM 调用全局限速器(令牌桶算法)
所有 LLMProvider 实例共享同一个 RateLimiter 单例,
确保跨 Provider 的总调用频率不超过配置上限。
"""
import asyncio
import logging
import os
import time
from collections import deque
logger = logging.getLogger(__name__)
class TokenBucketRateLimiter:
"""基于令牌桶算法的速率限制器
默认 30 RPM每秒补充约 0.5 个令牌),可通过环境变量
LLM_RATE_LIMIT_RPM 调整。
"""
_instance: "TokenBucketRateLimiter | None" = None
_lock = asyncio.Lock()
def __init__(
self,
max_rpm: float = 30.0,
):
self._max_rpm = max_rpm
self._refill_rate = max_rpm / 60.0 # tokens per second
self._max_tokens = max_rpm
self._tokens = max_rpm # start full
self._last_refill = time.monotonic()
self._semaphore = asyncio.Semaphore(1) # 一次只能有一个 acquire 等待
# 用于记录最近请求时间(调试/监控用)
self._recent_requests: deque[float] = deque(maxlen=int(max_rpm * 2))
@classmethod
def get_instance(cls) -> "TokenBucketRateLimiter":
"""获取全局单例"""
if cls._instance is None:
rpm = float(os.getenv("LLM_RATE_LIMIT_RPM", "30"))
cls._instance = cls(max_rpm=rpm)
logger.info(f"LLM RateLimiter initialized: {rpm} RPM")
return cls._instance
@classmethod
async def reset_instance(cls) -> None:
"""重置单例(仅用于测试)"""
async with cls._lock:
if cls._instance is not None:
cls._instance = None
def _refill(self) -> None:
"""补充令牌"""
now = time.monotonic()
elapsed = now - self._last_refill
tokens_to_add = elapsed * self._refill_rate
self._tokens = min(self._max_tokens, self._tokens + tokens_to_add)
self._last_refill = now
async def acquire(self) -> None:
"""获取一个令牌,若无可用令牌则等待
此方法可安全地从多个协程并发调用。
"""
async with self._semaphore:
self._refill()
if self._tokens >= 1.0:
self._tokens -= 1.0
self._recent_requests.append(time.monotonic())
return
# 如果没有可用令牌,计算等待时间
wait_time = (1.0 - self._tokens) / self._refill_rate
if wait_time > 0:
logger.debug(f"LLM rate limiter: waiting {wait_time:.2f}s for token")
await asyncio.sleep(wait_time)
# 重试获取
async with self._semaphore:
self._refill()
self._tokens = max(0.0, self._tokens - 1.0)
self._recent_requests.append(time.monotonic())
@property
def available_tokens(self) -> float:
"""当前可用令牌数"""
return self._tokens
@property
def max_rpm(self) -> float:
"""配置的最大 RPM"""
return self._max_rpm
# Module-level convenience accessors.

# NOTE(review): not read or assigned by any code shown here; the shared
# limiter actually lives on TokenBucketRateLimiter._instance. Kept only in
# case external code imports or patches this name.
_rate_limiter: TokenBucketRateLimiter | None = None


def get_rate_limiter() -> TokenBucketRateLimiter:
    """Return the shared LLM rate limiter.

    Thin wrapper around ``TokenBucketRateLimiter.get_instance()``.
    """
    limiter = TokenBucketRateLimiter.get_instance()
    return limiter


# Short alias, kept so the name expected by the test acceptance criteria
# resolves to the same class.
RateLimiter = TokenBucketRateLimiter