geo/backend/app/services/api_key_manager.py

131 lines
4.0 KiB
Python

import base64
import logging
import os
from dataclasses import dataclass
from enum import Enum
logger = logging.getLogger(__name__)
class KeySource(str, Enum):
    """Origin of a registered API key.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    # Registered as a system key; the default source for APIKeyManager.add_key.
    SYSTEM = "system"
    # Registered for one particular user.
    USER = "user"
    # Picked up from a process environment variable.
    ENV = "env"
class KeyStatus(str, Enum):
    """Verification state recorded for an API key.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    ACTIVE = "active"
    INVALID = "invalid"
    EXPIRED = "expired"
    RATE_LIMITED = "rate_limited"
    # Initial state of a key that has not been verified yet.
    UNKNOWN = "unknown"
@dataclass
class APIKeyConfig:
    """Stored record describing one API key for one engine."""

    # Engine identifier the key belongs to (e.g. "chatgpt").
    engine_type: str
    # Where the key came from.
    key_source: KeySource
    # Encoded form of the key as stored by the manager.
    encrypted_key: str
    # Masked preview of the key, safe to display.
    key_hint: str
    # Verification state; starts as UNKNOWN.
    status: KeyStatus = KeyStatus.UNKNOWN
    # Ordering weight within an engine; higher values are tried first.
    priority: int = 0
    # Timestamp of the last verification, if any (string format not fixed here).
    last_verified_at: str | None = None
    # Creation timestamp, if recorded (string format not fixed here).
    created_at: str | None = None
    # Owning user for USER-sourced keys; None otherwise.
    user_id: str | None = None
class APIKeyManager:
    """In-memory registry of API keys for the supported AI engines.

    Keys are grouped by engine type and each group is kept sorted by
    descending priority. ``get_key`` prefers a matching user's own key and
    otherwise falls back to SYSTEM / ENV keys.

    SECURITY NOTE(review): ``_encrypt`` / ``_decrypt`` only Base64-encode the
    key. Base64 is a reversible encoding, not encryption, and
    ``_ENCRYPTION_KEY`` is not referenced by any method of this class. Treat
    ``APIKeyConfig.encrypted_key`` as plaintext-equivalent until a real
    authenticated cipher is wired in.
    """

    # Read once when the class is defined. The hard-coded default is a
    # placeholder and must not be relied on in production.
    _ENCRYPTION_KEY = os.getenv(
        "API_KEY_ENCRYPTION_KEY",
        "geo-platform-default-key-change-in-production",
    )
    # Statuses under which get_key() may still hand a key out.
    _USABLE_STATUSES = frozenset({KeyStatus.ACTIVE, KeyStatus.UNKNOWN})
    # Sources that act as the shared fallback when no user key matches.
    _FALLBACK_SOURCES = frozenset({KeySource.SYSTEM, KeySource.ENV})
    # Engine type -> environment variable consulted by load_env_keys().
    _ENV_MAPPING = {
        "chatgpt": "OPENAI_API_KEY",
        "perplexity": "PERPLEXITY_API_KEY",
        "kimi": "MOONSHOT_API_KEY",
        "wenxin": "BAIDU_QIANFAN_API_KEY",
        "doubao": "DOUBAO_API_KEY",
        "deepseek": "DEEPSEEK_API_KEY",
        "qwen": "DASHSCOPE_API_KEY",
        "gemini": "GOOGLE_API_KEY",
        "yuanbao": "HUNYUAN_API_KEY",
    }

    def __init__(self):
        """Create an empty registry."""
        # engine type -> configs, sorted by descending priority
        self._keys: dict[str, list[APIKeyConfig]] = {}

    def add_key(
        self,
        engine_type: str,
        api_key: str,
        source: KeySource = KeySource.SYSTEM,
        user_id: str | None = None,
        priority: int = 0,
    ) -> APIKeyConfig:
        """Register a key for an engine and return its stored config.

        Args:
            engine_type: Engine identifier the key belongs to.
            api_key: The raw key; stored encoded, with a masked hint.
            source: Where the key comes from.
            user_id: Owner of the key, relevant for USER-sourced keys.
            priority: Ordering weight; higher values are tried first.

        Returns:
            The newly created config, with status UNKNOWN.
        """
        config = APIKeyConfig(
            engine_type=engine_type,
            key_source=source,
            encrypted_key=self._encrypt(api_key),
            key_hint=self._mask_key(api_key),
            status=KeyStatus.UNKNOWN,
            priority=priority,
            user_id=user_id,
        )
        bucket = self._keys.setdefault(engine_type, [])
        bucket.append(config)
        # list.sort is stable (also with reverse=True), so keys of equal
        # priority stay in insertion order.
        bucket.sort(key=lambda k: k.priority, reverse=True)
        return config

    def get_key(self, engine_type: str, user_id: str | None = None) -> str | None:
        """Return the best usable raw key for an engine, or None.

        When ``user_id`` is given, that user's own usable key wins. Otherwise,
        or when the user has none, the highest-priority usable SYSTEM / ENV
        key is returned. "Usable" means status ACTIVE or UNKNOWN.
        """
        configs = self._keys.get(engine_type, [])
        if user_id:
            for c in configs:
                if (
                    c.key_source == KeySource.USER
                    and c.user_id == user_id
                    and c.status in self._USABLE_STATUSES
                ):
                    return self._decrypt(c.encrypted_key)
        for c in configs:
            if c.key_source in self._FALLBACK_SOURCES and c.status in self._USABLE_STATUSES:
                return self._decrypt(c.encrypted_key)
        return None

    def remove_key(self, engine_type: str, key_hint: str) -> bool:
        """Remove the first key of an engine whose masked hint matches.

        Hints are not unique (every key of 8 characters or fewer masks to
        ``"***"``), so only the first match in priority order is removed.

        Returns:
            True if a key was removed, False if nothing matched.
        """
        configs = self._keys.get(engine_type, [])
        for i, c in enumerate(configs):
            if c.key_hint == key_hint:
                # Safe to mutate here: we return immediately after the pop.
                configs.pop(i)
                return True
        return False

    def list_keys(self, engine_type: str | None = None) -> list[APIKeyConfig]:
        """Return the configs for one engine, or for all engines.

        The result is always a new list, so adding or removing items on it
        never alters the registry. The contained ``APIKeyConfig`` objects are
        the stored instances, not copies.

        Args:
            engine_type: Engine to list; a falsy value lists every engine.
        """
        if engine_type:
            # Copy, so callers cannot mutate the internal per-engine list.
            return list(self._keys.get(engine_type, []))
        return [config for configs in self._keys.values() for config in configs]

    async def verify_key(self, engine_type: str, api_key: str) -> KeyStatus:
        """Perform a minimal sanity check on a key.

        Only the key length is checked: empty keys and keys shorter than 10
        characters are INVALID, everything else is reported ACTIVE.
        ``engine_type`` is currently unused and no request is made.
        """
        if not api_key or len(api_key) < 10:
            return KeyStatus.INVALID
        return KeyStatus.ACTIVE

    def _encrypt(self, plaintext: str) -> str:
        """Encode a key for storage.

        SECURITY NOTE(review): this is Base64 encoding, not encryption.
        """
        return base64.b64encode(plaintext.encode()).decode()

    def _decrypt(self, ciphertext: str) -> str:
        """Reverse ``_encrypt`` (Base64-decode back to the raw key)."""
        return base64.b64decode(ciphertext.encode()).decode()

    def _mask_key(self, key: str) -> str:
        """Return a display-safe hint: first 3 and last 3 characters.

        Keys of 8 characters or fewer are fully masked as ``"***"``.
        """
        if len(key) <= 8:
            return "***"
        return key[:3] + "..." + key[-3:]

    def load_env_keys(self):
        """Register keys found in the environment variables of ``_ENV_MAPPING``.

        Unset or empty variables are skipped. A key already registered as an
        ENV key for the same engine is not added again, so calling this
        method repeatedly does not accumulate duplicates.
        """
        for engine, env_var in self._ENV_MAPPING.items():
            key = os.getenv(env_var, "")
            if not key:
                continue
            already_registered = any(
                c.key_source == KeySource.ENV and self._decrypt(c.encrypted_key) == key
                for c in self._keys.get(engine, [])
            )
            if not already_registered:
                self.add_key(engine, key, source=KeySource.ENV, priority=0)