131 lines
4.0 KiB
Python
131 lines
4.0 KiB
Python
import base64
import logging
import os
from dataclasses import dataclass, field
from enum import Enum
# Logger for this module, named after the module itself.
logger: logging.Logger = logging.getLogger(name=__name__)


class KeySource(str, Enum):
    """Where a registered API key came from.

    Members mix in ``str``, so each one compares equal to its string value
    (e.g. ``KeySource.USER == "user"``).
    """

    # Default source used by APIKeyManager.add_key; part of the shared fallback.
    SYSTEM = "system"
    # Key owned by a particular user; APIKeyManager.get_key matches it by user_id.
    USER = "user"
    # Key read from an environment variable by APIKeyManager.load_env_keys.
    ENV = "env"


class KeyStatus(str, Enum):
    """Lifecycle state recorded on an API key configuration.

    Members mix in ``str``, so each one compares equal to its string value.
    """

    # Returned by APIKeyManager.verify_key for a plausible key; still usable.
    ACTIVE = "active"
    # Returned by APIKeyManager.verify_key for an empty or too-short key.
    INVALID = "invalid"
    # NOTE(review): not assigned anywhere in this module; presumably set by
    # external verification logic - confirm.
    EXPIRED = "expired"
    # NOTE(review): not assigned anywhere in this module; presumably set by
    # external verification logic - confirm.
    RATE_LIMITED = "rate_limited"
    # Initial status of newly added keys; treated as usable by get_key.
    UNKNOWN = "unknown"


@dataclass
class APIKeyConfig:
    """Stored record for one API key registered with APIKeyManager.

    ``encrypted_key`` is excluded from ``repr()``. APIKeyManager only
    Base64-encodes the raw key into this field, so the value is trivially
    reversible. Printing or logging a config would otherwise leak the secret.
    Equality comparison is unaffected.
    """

    # Engine this key is registered under.
    engine_type: str
    # Origin of the key (system, user, or environment).
    key_source: KeySource
    # Encoded secret produced by APIKeyManager._encrypt; treat as sensitive.
    encrypted_key: str = field(repr=False)
    # Masked form of the key (APIKeyManager._mask_key), safe to display.
    key_hint: str
    # Current status; APIKeyManager.get_key only hands out ACTIVE/UNKNOWN keys.
    status: KeyStatus = KeyStatus.UNKNOWN
    # Higher values are tried first within an engine.
    priority: int = 0
    # NOTE(review): timestamp format is not established here (presumably
    # ISO-8601) - confirm against the code that sets it.
    last_verified_at: str | None = None
    created_at: str | None = None
    # Owner of a KeySource.USER key; None for shared keys.
    user_id: str | None = None


class APIKeyManager:
    """In-memory registry of API keys, grouped by engine type.

    Each engine's keys are kept sorted by descending ``priority``.
    :meth:`get_key` first looks for a usable key owned by the requesting user.
    If none is found, it falls back to the first usable system/env key in
    priority order.

    NOTE(review): ``_encrypt``/``_decrypt`` only Base64-encode the key, which
    anyone can reverse without a secret. ``_ENCRYPTION_KEY`` is read but never
    used. Treat ``APIKeyConfig.encrypted_key`` as a plaintext secret until a
    real cipher is wired in. Changing the encoding would also invalidate any
    ``encrypted_key`` values persisted elsewhere, so it is left as-is here.
    """

    # Read from the environment once, when the class body executes.
    # NOTE(review): not referenced by any method of this class.
    _ENCRYPTION_KEY = os.getenv("API_KEY_ENCRYPTION_KEY", "geo-platform-default-key-change-in-production")

    # Statuses for which get_key() will still return a stored key.
    _USABLE_STATUSES = frozenset({KeyStatus.ACTIVE, KeyStatus.UNKNOWN})

    # Sources used as the shared fallback when no user-owned key applies.
    _FALLBACK_SOURCES = frozenset({KeySource.SYSTEM, KeySource.ENV})

    # engine_type -> environment variable consulted by load_env_keys().
    _ENV_MAPPING = {
        "chatgpt": "OPENAI_API_KEY",
        "perplexity": "PERPLEXITY_API_KEY",
        "kimi": "MOONSHOT_API_KEY",
        "wenxin": "BAIDU_QIANFAN_API_KEY",
        "doubao": "DOUBAO_API_KEY",
        "deepseek": "DEEPSEEK_API_KEY",
        "qwen": "DASHSCOPE_API_KEY",
        "gemini": "GOOGLE_API_KEY",
        "yuanbao": "HUNYUAN_API_KEY",
    }

    def __init__(self):
        # engine_type -> its key configs, sorted by descending priority.
        self._keys: dict[str, list[APIKeyConfig]] = {}

    def add_key(
        self,
        engine_type: str,
        api_key: str,
        source: KeySource = KeySource.SYSTEM,
        user_id: str | None = None,
        priority: int = 0,
    ) -> APIKeyConfig:
        """Register ``api_key`` for ``engine_type`` and return its config.

        The new config starts with status ``KeyStatus.UNKNOWN``, which
        get_key() treats as usable.

        Args:
            engine_type: Engine the key belongs to.
            api_key: The raw key. It is stored via _encrypt() and masked into
                ``key_hint``.
            source: Origin of the key; defaults to ``KeySource.SYSTEM``.
            user_id: Owner of a ``KeySource.USER`` key; matched by get_key().
            priority: Higher values are tried first within the engine.

        Returns:
            The stored APIKeyConfig (the same object the manager keeps).
        """
        config = APIKeyConfig(
            engine_type=engine_type,
            key_source=source,
            encrypted_key=self._encrypt(api_key),
            key_hint=self._mask_key(api_key),
            status=KeyStatus.UNKNOWN,
            priority=priority,
            user_id=user_id,
        )
        configs = self._keys.setdefault(engine_type, [])
        configs.append(config)
        # list.sort is stable even with reverse=True, so keys of equal
        # priority keep their insertion order.
        configs.sort(key=lambda k: k.priority, reverse=True)
        return config

    def get_key(self, engine_type: str, user_id: str | None = None) -> str | None:
        """Return the decoded key to use for ``engine_type``, or ``None``.

        Lookup order, scanning the engine's keys in priority order:
          1. If ``user_id`` is truthy, the first USER key owned by that user
             with a usable status.
          2. Otherwise, or if no such key exists, the first SYSTEM/ENV key
             with a usable status.
        """
        configs = self._keys.get(engine_type, [])
        if user_id:
            for c in configs:
                if (
                    c.key_source == KeySource.USER
                    and c.user_id == user_id
                    and c.status in self._USABLE_STATUSES
                ):
                    return self._decrypt(c.encrypted_key)
        for c in configs:
            if c.key_source in self._FALLBACK_SOURCES and c.status in self._USABLE_STATUSES:
                return self._decrypt(c.encrypted_key)
        return None

    def remove_key(self, engine_type: str, key_hint: str) -> bool:
        """Remove the first key of ``engine_type`` whose hint equals ``key_hint``.

        NOTE(review): hints are not unique. Every key of 8 characters or fewer
        masks to ``"***"``, and longer keys keep only 3 leading and 3 trailing
        characters. This may therefore remove a different key that shares the
        same hint.

        Returns:
            True if a key was removed, False if no hint matched.
        """
        configs = self._keys.get(engine_type, [])
        for i, c in enumerate(configs):
            if c.key_hint == key_hint:
                # Deleting while iterating is safe here: we return immediately.
                del configs[i]
                return True
        return False

    def list_keys(self, engine_type: str | None = None) -> list[APIKeyConfig]:
        """Return registered key configs for one engine or for all engines.

        A new list is returned in both cases, so callers cannot accidentally
        reorder or drop entries from the manager's internal storage. The
        APIKeyConfig objects themselves are shared, not copied.

        Args:
            engine_type: Restrict the result to this engine. A falsy value
                lists every engine's keys, grouped by engine and in priority
                order within each group.
        """
        if engine_type:
            return list(self._keys.get(engine_type, []))
        return [c for configs in self._keys.values() for c in configs]

    async def verify_key(self, engine_type: str, api_key: str) -> KeyStatus:
        """Return a coarse validity status for ``api_key``.

        Returns ``KeyStatus.INVALID`` when the key is empty or shorter than 10
        characters, otherwise ``KeyStatus.ACTIVE``.

        NOTE(review): this is a placeholder check. ``engine_type`` is unused,
        no request is made to the engine, and no stored config's status is
        updated.
        """
        if not api_key or len(api_key) < 10:
            return KeyStatus.INVALID
        return KeyStatus.ACTIVE

    def _encrypt(self, plaintext: str) -> str:
        """Encode ``plaintext`` for storage.

        NOTE(review): this is Base64 only, not encryption. The result is
        reversible without any key.
        """
        return base64.b64encode(plaintext.encode()).decode()

    def _decrypt(self, ciphertext: str) -> str:
        """Reverse _encrypt().

        May raise ``binascii.Error`` or ``UnicodeDecodeError`` on malformed
        input.
        """
        return base64.b64decode(ciphertext.encode()).decode()

    def _mask_key(self, key: str) -> str:
        """Return a display hint: ``"***"`` for keys of 8 or fewer characters,
        otherwise the first 3 and last 3 characters joined by ``"..."``."""
        if len(key) <= 8:
            return "***"
        return key[:3] + "..." + key[-3:]

    def load_env_keys(self):
        """Register an ENV key for each engine whose variable in _ENV_MAPPING is set.

        Unset or empty variables are skipped. If the same key value is already
        registered from the environment for that engine, it is not added
        again. Calling this method repeatedly therefore does not create
        duplicate entries.
        """
        for engine, env_var in self._ENV_MAPPING.items():
            key = os.getenv(env_var, "")
            if not key:
                continue
            # Compare decoded values rather than stored strings, so the check
            # keeps working if _encrypt ever becomes non-deterministic.
            already_loaded = any(
                c.key_source == KeySource.ENV and self._decrypt(c.encrypted_key) == key
                for c in self._keys.get(engine, [])
            )
            if already_loaded:
                continue
            self.add_key(engine, key, source=KeySource.ENV, priority=0)
            logger.debug("Loaded API key for %s from %s", engine, env_var)