253 lines
8.3 KiB
Python
253 lines
8.3 KiB
Python
import base64
|
|
import logging
|
|
import os
|
|
import uuid
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from enum import Enum
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.repositories.api_key_repository import APIKeyRepository
|
|
from app.services.key_encryption import KeyEncryption, get_key_encryption
|
|
from app.services.key_verifier import KeyStatus, KeyVerifierFactory
|
|
|
|
# Module-level logger named after this module; APIKeyManager.verify_key
# reports verification failures through it.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class KeySource(str, Enum):
    """Where an API key came from.

    Thanks to the ``str`` mix-in, each member compares equal to its plain
    string value, which is the form persisted in ``key_source`` columns.
    """

    SYSTEM = "system"  # default source for APIKeyManager.add_key
    USER = "user"  # key owned by a specific user (matched by user_id)
    ENV = "env"  # key read from an environment variable
|
|
|
|
|
|
@dataclass
class APIKeyConfig:
    """Metadata for one stored API key, with the key itself kept encrypted.

    Fields are positional in the order declared below; only ``status`` and
    later fields have defaults.
    """

    # Engine identifier, e.g. "chatgpt" or "deepseek".
    engine_type: str
    # Origin of the key (system / user / env).
    key_source: KeySource
    # Ciphertext of the key, as produced by the key-encryption service.
    encrypted_key: str
    # Masked rendering of the key, suitable for display.
    key_hint: str
    # Verification state; new keys start out UNKNOWN.
    status: KeyStatus = KeyStatus.UNKNOWN
    # Higher values are tried first by APIKeyManager's in-memory lookups.
    priority: int = 0
    # Timestamps are carried as strings, not datetime objects.
    last_verified_at: str | None = None
    created_at: str | None = None
    # Owning user's id rendered as a string, if the key has an owner.
    user_id: str | None = None
|
|
|
|
|
|
class APIKeyManager:
    """Store, look up, remove and verify per-engine API keys.

    Two storage back ends are exposed side by side:

    * an in-memory registry (``add_key``, ``get_key``, ``list_keys``, ...)
      in which each engine's keys are kept sorted by descending ``priority``;
    * a database-backed variant (the ``*_async`` methods) that goes through
      ``APIKeyRepository`` on a caller-supplied ``AsyncSession``.

    Keys are stored encrypted via ``get_key_encryption()`` together with a
    short masked hint; plaintext is only returned by the lookup methods.
    In both back ends a user's own key takes precedence over SYSTEM/ENV keys.
    """

    # NOTE(review): not read anywhere in this class -- encryption goes through
    # get_key_encryption(). Kept only for backward compatibility; the
    # hard-coded fallback value must never be relied on in production.
    _ENCRYPTION_KEY = os.getenv("API_KEY_ENCRYPTION_KEY", "geo-platform-default-key-change-in-production")

    # Statuses under which a key may be handed out. UNKNOWN is included so
    # freshly added, not-yet-verified keys are usable immediately.
    _USABLE_STATUSES = frozenset({KeyStatus.ACTIVE, KeyStatus.UNKNOWN})

    # Sources eligible in the user-specific (first) lookup pass.
    _USER_SOURCES = frozenset({KeySource.USER})

    # Sources that may serve any caller when no user key applies.
    _FALLBACK_SOURCES = frozenset({KeySource.SYSTEM, KeySource.ENV})

    # Engine type -> environment variable consulted by load_env_keys().
    _ENV_MAPPING = {
        "chatgpt": "OPENAI_API_KEY",
        "perplexity": "PERPLEXITY_API_KEY",
        "kimi": "MOONSHOT_API_KEY",
        "wenxin": "BAIDU_QIANFAN_API_KEY",
        "doubao": "DOUBAO_API_KEY",
        "deepseek": "DEEPSEEK_API_KEY",
        "qwen": "DASHSCOPE_API_KEY",
        "gemini": "GOOGLE_API_KEY",
        "yuanbao": "HUNYUAN_API_KEY",
    }

    def __init__(self):
        # engine_type -> configs, each list sorted by descending priority.
        self._keys: dict[str, list[APIKeyConfig]] = {}

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _build_config(
        self,
        *,
        engine_type: str,
        api_key: str,
        source: KeySource,
        user_id: str | None,
        priority: int,
    ) -> APIKeyConfig:
        """Encrypt *api_key* and wrap it in a fresh, not-yet-verified config."""
        return APIKeyConfig(
            engine_type=engine_type,
            key_source=source,
            encrypted_key=self._encrypt(api_key),
            key_hint=self._mask_key(api_key),
            status=KeyStatus.UNKNOWN,
            priority=priority,
            user_id=user_id,
        )

    @staticmethod
    def _config_from_record(record) -> APIKeyConfig:
        """Convert a repository record into an APIKeyConfig.

        Timestamps and the user id are stringified; falsy values become None.
        """
        return APIKeyConfig(
            engine_type=record.engine_type,
            key_source=KeySource(record.key_source),
            encrypted_key=record.encrypted_key,
            key_hint=record.key_hint,
            status=KeyStatus(record.status),
            priority=record.priority,
            last_verified_at=str(record.last_verified_at) if record.last_verified_at else None,
            created_at=str(record.created_at) if record.created_at else None,
            user_id=str(record.user_id) if record.user_id else None,
        )

    def _first_usable_record_key(self, records: list, sources: frozenset) -> str | None:
        """Decrypt the first record whose source is in *sources* and status is usable.

        Records are scanned in the order given; returns None if none match.
        """
        for record in records:
            if (
                KeySource(record.key_source) in sources
                and KeyStatus(record.status) in self._USABLE_STATUSES
            ):
                return self._decrypt(record.encrypted_key)
        return None

    # ------------------------------------------------------------------
    # In-memory registry
    # ------------------------------------------------------------------

    def add_key(
        self,
        engine_type: str,
        api_key: str,
        source: KeySource = KeySource.SYSTEM,
        user_id: str | None = None,
        priority: int = 0,
    ) -> APIKeyConfig:
        """Encrypt and register *api_key* for *engine_type* in memory.

        Args:
            engine_type: Engine the key belongs to.
            api_key: Plaintext key; only its ciphertext and hint are stored.
            source: Origin of the key.
            user_id: Owning user, used to match USER keys in ``get_key``.
            priority: Higher values are returned first by lookups.

        Returns:
            The stored config (the same object held in the registry).
        """
        config = self._build_config(
            engine_type=engine_type,
            api_key=api_key,
            source=source,
            user_id=user_id,
            priority=priority,
        )
        configs = self._keys.setdefault(engine_type, [])
        configs.append(config)
        # list.sort is stable (also with reverse=True), so keys of equal
        # priority keep their insertion order.
        configs.sort(key=lambda k: k.priority, reverse=True)
        return config

    def get_key(self, engine_type: str, user_id: str | None = None) -> str | None:
        """Return the decrypted key to use for *engine_type*, or None.

        When *user_id* is given, the highest-priority usable USER key owned by
        that user wins; otherwise (or if there is none) the highest-priority
        usable SYSTEM/ENV key is returned.
        """
        configs = self._keys.get(engine_type, [])
        if user_id:
            for c in configs:
                if (
                    c.key_source == KeySource.USER
                    and c.user_id == user_id
                    and c.status in self._USABLE_STATUSES
                ):
                    return self._decrypt(c.encrypted_key)
        for c in configs:
            if c.key_source in self._FALLBACK_SOURCES and c.status in self._USABLE_STATUSES:
                return self._decrypt(c.encrypted_key)
        return None

    def get_any_available_key(self, engine_type: str) -> str | None:
        """Return the highest-priority usable key of any source or owner, or None."""
        for c in self._keys.get(engine_type, []):
            if c.status in self._USABLE_STATUSES:
                return self._decrypt(c.encrypted_key)
        return None

    def remove_key(self, engine_type: str, key_hint: str) -> bool:
        """Remove the first (highest-priority) key of *engine_type* with *key_hint*.

        NOTE: hints are not unique -- every key of 8 characters or fewer masks
        to "***" -- so this may not remove the key the caller had in mind.

        Returns:
            True if a key was removed, False if no hint matched.
        """
        configs = self._keys.get(engine_type, [])
        for i, c in enumerate(configs):
            if c.key_hint == key_hint:
                configs.pop(i)
                return True
        return False

    def list_keys(self, engine_type: str | None = None) -> list[APIKeyConfig]:
        """List stored configs, optionally restricted to one engine.

        A new list is returned so callers cannot add, drop or reorder the
        registry's entries. The APIKeyConfig objects themselves are shared,
        so mutating one (e.g. its ``status``) does affect later lookups.
        """
        if engine_type:
            return list(self._keys.get(engine_type, []))
        return [config for configs in self._keys.values() for config in configs]

    async def verify_key(self, engine_type: str, api_key: str) -> KeyStatus:
        """Check *api_key* against the engine's verifier.

        Keys that are empty or shorter than 10 characters are rejected as
        INVALID without a remote call. Verifier errors are logged and mapped to
        UNKNOWN (best effort) rather than propagated.
        """
        if not api_key or len(api_key) < 10:
            return KeyStatus.INVALID
        try:
            return await KeyVerifierFactory.verify(engine_type, api_key)
        except Exception as e:
            logger.warning("[api_key_manager] Key verification failed: %s", e)
            return KeyStatus.UNKNOWN

    def _encrypt(self, plaintext: str) -> str:
        """Encrypt *plaintext* with the shared key-encryption service."""
        return get_key_encryption().encrypt(plaintext)

    def _decrypt(self, ciphertext: str) -> str:
        """Decrypt *ciphertext* with the shared key-encryption service."""
        return get_key_encryption().decrypt(ciphertext)

    def _mask_key(self, key: str) -> str:
        """Return a display hint: first and last 3 chars, or "***" if <= 8 chars."""
        if len(key) <= 8:
            return "***"
        return key[:3] + "..." + key[-3:]

    def load_env_keys(self):
        """Register an ENV key (priority 0) for every mapped variable that is set.

        NOTE(review): not idempotent -- calling it twice registers duplicate
        ENV entries for the same variables.
        """
        for engine, env_var in self._ENV_MAPPING.items():
            key = os.getenv(env_var, "")
            if key:
                self.add_key(engine, key, source=KeySource.ENV, priority=0)

    # ------------------------------------------------------------------
    # Database-backed variants
    # ------------------------------------------------------------------

    async def add_key_async(
        self,
        session: AsyncSession,
        engine_type: str,
        api_key: str,
        source: KeySource = KeySource.USER,
        user_id: uuid.UUID | None = None,
        priority: int = 0,
    ) -> APIKeyConfig:
        """Encrypt *api_key* and persist it via the repository.

        Returns:
            A config describing the stored key. It is built locally, so
            ``created_at`` / ``last_verified_at`` are not populated.
        """
        config = self._build_config(
            engine_type=engine_type,
            api_key=api_key,
            source=source,
            user_id=str(user_id) if user_id else None,
            priority=priority,
        )
        repository = APIKeyRepository(session)
        await repository.create(
            user_id=user_id,
            engine_type=engine_type,
            encrypted_key=config.encrypted_key,
            key_hint=config.key_hint,
            key_source=source.value,
            status=config.status.value,
            priority=priority,
        )
        return config

    async def get_key_async(
        self,
        session: AsyncSession,
        engine_type: str,
        user_id: uuid.UUID | None = None,
    ) -> str | None:
        """Return the decrypted key to use for *engine_type* from the database.

        Mirrors ``get_key``: when *user_id* is given, a usable USER key always
        wins over SYSTEM/ENV keys, whatever order the repository returns the
        records in. Within each pass the repository order is kept.

        NOTE(review): ownership of USER records is not re-checked here; this
        presumably relies on ``get_by_user`` filtering by *user_id* -- confirm.
        """
        repository = APIKeyRepository(session)
        # Materialise so the records can be scanned twice, even if the
        # repository hands back a one-shot iterable.
        records = list(await repository.get_by_user(user_id, engine_type))
        if user_id:
            user_key = self._first_usable_record_key(records, self._USER_SOURCES)
            if user_key is not None:
                return user_key
        return self._first_usable_record_key(records, self._FALLBACK_SOURCES)

    async def get_any_available_key_async(
        self,
        session: AsyncSession,
        engine_type: str,
    ) -> str | None:
        """Return the first usable key among ``get_by_user(None, engine_type)`` records.

        NOTE(review): unlike ``get_any_available_key``, this only sees what the
        repository returns for a None user -- confirm whether that includes
        user-owned keys.
        """
        repository = APIKeyRepository(session)
        records = await repository.get_by_user(None, engine_type)
        for record in records:
            if KeyStatus(record.status) in self._USABLE_STATUSES:
                return self._decrypt(record.encrypted_key)
        return None

    async def remove_key_async(
        self,
        session: AsyncSession,
        user_id: uuid.UUID,
        engine_type: str,
        key_hint: str,
    ) -> bool:
        """Delete the first record of *user_id*/*engine_type* whose hint matches.

        Returns:
            The repository's delete result, or False if no hint matched.
        """
        repository = APIKeyRepository(session)
        records = await repository.get_by_user(user_id, engine_type)
        for record in records:
            if record.key_hint == key_hint:
                return await repository.delete(record.id)
        return False

    async def list_keys_async(
        self,
        session: AsyncSession,
        user_id: uuid.UUID | None = None,
        engine_type: str | None = None,
    ) -> list[APIKeyConfig]:
        """List stored keys as configs, for one user or (without *user_id*) all."""
        repository = APIKeyRepository(session)
        if user_id:
            records = await repository.get_by_user(user_id, engine_type)
        else:
            records = await repository.list_all(engine_type=engine_type)
        return [self._config_from_record(record) for record in records]

    async def update_key_status_async(
        self,
        session: AsyncSession,
        key_id: uuid.UUID,
        status: KeyStatus,
    ) -> bool:
        """Set the stored status of key *key_id*.

        Returns:
            True if the repository reported an updated record.
        """
        repository = APIKeyRepository(session)
        updated = await repository.update(key_id, status=status.value)
        return updated is not None
|