geo/backend/app/services/api_key_manager.py

253 lines
8.3 KiB
Python

import base64
import logging
import os
import uuid
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from sqlalchemy.ext.asyncio import AsyncSession
from app.repositories.api_key_repository import APIKeyRepository
from app.services.key_encryption import KeyEncryption, get_key_encryption
from app.services.key_verifier import KeyStatus, KeyVerifierFactory
# Module logger; APIKeyManager.verify_key reports verifier failures through it.
logger = logging.getLogger(__name__)
class KeySource(str, Enum):
    """Origin of an API key.

    Mixes in ``str`` so members compare equal to their raw string values.
    Those values are what APIKeyManager.add_key_async persists
    (``source.value``) and what ``KeySource(...)`` parses back from
    repository rows.
    """
    SYSTEM = "system"  # default source of APIKeyManager.add_key; used as a fallback in lookups
    USER = "user"  # owned by a specific user; matched on user_id in lookups
    ENV = "env"  # read from environment variables by APIKeyManager.load_env_keys
@dataclass
class APIKeyConfig:
    """Metadata for one stored API key; the plaintext key is not a field.

    NOTE(review): the dataclass-generated ``repr`` includes ``encrypted_key``;
    consider ``field(repr=False)`` if these objects are ever logged.
    """

    # Engine identifier, e.g. "chatgpt" or "deepseek".
    engine_type: str
    # Where the key came from (system / user / env).
    key_source: KeySource
    # Ciphertext produced by the key-encryption service; decrypt before use.
    encrypted_key: str
    # Masked form of the plaintext key (as built by APIKeyManager._mask_key),
    # safe to display and used to identify a key when removing it.
    key_hint: str
    # Verification status; newly added keys start as UNKNOWN.
    status: KeyStatus = KeyStatus.UNKNOWN
    # Higher values are tried first (per-engine lists are sorted descending).
    priority: int = 0
    # Stringified timestamp from the repository row, if any.
    last_verified_at: str | None = None
    # Stringified timestamp from the repository row, if any.
    created_at: str | None = None
    # Owner's id as a string, or None when the key has no owner.
    user_id: str | None = None
class APIKeyManager:
    """Register, look up, verify and persist API keys per engine type.

    Keys are encrypted with ``get_key_encryption()`` before they are stored
    and are only decrypted when handed out. Two storage back-ends live side
    by side:

    * the synchronous methods (``add_key``, ``get_key``, ``remove_key``, ...)
      work on an in-memory registry owned by this instance;
    * the ``*_async`` methods work on the database via ``APIKeyRepository``.

    Only keys whose status is in ``_USABLE_STATUSES`` are ever returned.
    When a user id is given, that user's own USER keys take precedence over
    SYSTEM/ENV keys.
    """

    # NOTE(review): not referenced anywhere in this class -- encryption goes
    # through get_key_encryption(). The hard-coded fallback secret is a risk
    # if anything else reads this attribute; confirm and remove.
    _ENCRYPTION_KEY = os.getenv("API_KEY_ENCRYPTION_KEY", "geo-platform-default-key-change-in-production")
    # UNKNOWN is the status every key is created with, so freshly added keys
    # are usable before they have been verified.
    _USABLE_STATUSES = frozenset({KeyStatus.ACTIVE, KeyStatus.UNKNOWN})
    # Sources consulted when no usable user-owned key exists.
    _FALLBACK_SOURCES = frozenset({KeySource.SYSTEM, KeySource.ENV})
    # Engine type -> environment variable read by load_env_keys().
    _ENV_MAPPING = {
        "chatgpt": "OPENAI_API_KEY",
        "perplexity": "PERPLEXITY_API_KEY",
        "kimi": "MOONSHOT_API_KEY",
        "wenxin": "BAIDU_QIANFAN_API_KEY",
        "doubao": "DOUBAO_API_KEY",
        "deepseek": "DEEPSEEK_API_KEY",
        "qwen": "DASHSCOPE_API_KEY",
        "gemini": "GOOGLE_API_KEY",
        "yuanbao": "HUNYUAN_API_KEY",
    }

    def __init__(self):
        # Engine type -> key configs, kept sorted by descending priority.
        self._keys: dict[str, list[APIKeyConfig]] = {}

    def _build_config(
        self,
        engine_type: str,
        api_key: str,
        source: KeySource,
        user_id: str | None,
        priority: int,
    ) -> APIKeyConfig:
        """Encrypt and mask *api_key* into a fresh, not-yet-verified config."""
        return APIKeyConfig(
            engine_type=engine_type,
            key_source=source,
            encrypted_key=self._encrypt(api_key),
            key_hint=self._mask_key(api_key),
            status=KeyStatus.UNKNOWN,
            priority=priority,
            user_id=user_id,
        )

    # ------------------------------------------------------------------
    # In-memory registry
    # ------------------------------------------------------------------

    def add_key(
        self,
        engine_type: str,
        api_key: str,
        source: KeySource = KeySource.SYSTEM,
        user_id: str | None = None,
        priority: int = 0,
    ) -> APIKeyConfig:
        """Encrypt *api_key* and register it in memory for *engine_type*.

        Returns the stored config (status UNKNOWN).
        """
        config = self._build_config(engine_type, api_key, source, user_id, priority)
        configs = self._keys.setdefault(engine_type, [])
        configs.append(config)
        # list.sort is stable even with reverse=True, so keys of equal
        # priority keep their insertion order.
        configs.sort(key=lambda k: k.priority, reverse=True)
        return config

    def get_key(self, engine_type: str, user_id: str | None = None) -> str | None:
        """Return a decrypted usable key for *engine_type*, or None.

        With *user_id*, that user's usable USER keys are tried first (highest
        priority first). Otherwise, or if none match, the first usable
        SYSTEM/ENV key is returned.
        """
        configs = self._keys.get(engine_type, [])
        if user_id:
            for c in configs:
                if (
                    c.key_source == KeySource.USER
                    and c.user_id == user_id
                    and c.status in self._USABLE_STATUSES
                ):
                    return self._decrypt(c.encrypted_key)
        for c in configs:
            if c.key_source in self._FALLBACK_SOURCES and c.status in self._USABLE_STATUSES:
                return self._decrypt(c.encrypted_key)
        return None

    def get_any_available_key(self, engine_type: str) -> str | None:
        """Return the highest-priority usable key of any source, or None."""
        for c in self._keys.get(engine_type, []):
            if c.status in self._USABLE_STATUSES:
                return self._decrypt(c.encrypted_key)
        return None

    def remove_key(self, engine_type: str, key_hint: str) -> bool:
        """Remove the first in-memory key whose hint equals *key_hint*.

        Returns True if a key was removed. Hints are masked prefixes/suffixes,
        so distinct keys can share one; only the first match is removed.
        """
        configs = self._keys.get(engine_type, [])
        for i, c in enumerate(configs):
            if c.key_hint == key_hint:
                configs.pop(i)
                return True
        return False

    def list_keys(self, engine_type: str | None = None) -> list[APIKeyConfig]:
        """Return the registered configs, for one engine or for all of them.

        A new list is always returned so callers cannot reorder or resize
        the registry's internal per-engine lists.
        """
        if engine_type:
            return list(self._keys.get(engine_type, []))
        return [c for configs in self._keys.values() for c in configs]

    async def verify_key(self, engine_type: str, api_key: str) -> KeyStatus:
        """Check *api_key* against the engine's verifier.

        Keys shorter than 10 characters are rejected as INVALID without a
        remote call. Verifier errors are logged and reported as UNKNOWN
        rather than raised.
        """
        if not api_key or len(api_key) < 10:
            return KeyStatus.INVALID
        try:
            return await KeyVerifierFactory.verify(engine_type, api_key)
        except Exception as e:
            logger.warning("[api_key_manager] Key verification failed: %s", e)
            return KeyStatus.UNKNOWN

    def _encrypt(self, plaintext: str) -> str:
        """Encrypt *plaintext* with the shared key-encryption service."""
        return get_key_encryption().encrypt(plaintext)

    def _decrypt(self, ciphertext: str) -> str:
        """Decrypt *ciphertext* with the shared key-encryption service."""
        return get_key_encryption().decrypt(ciphertext)

    def _mask_key(self, key: str) -> str:
        """Return a display hint: first 3 + "..." + last 3 chars, or "***" if short.

        The hint is persisted and used to match keys on removal, so its
        format must stay stable.
        """
        if len(key) <= 8:
            return "***"
        return key[:3] + "..." + key[-3:]

    def load_env_keys(self):
        """Register an ENV key for every engine whose env variable is non-empty.

        NOTE(review): calling this twice registers duplicates.
        """
        for engine, env_var in self._ENV_MAPPING.items():
            key = os.getenv(env_var, "")
            if key:
                self.add_key(engine, key, source=KeySource.ENV, priority=0)

    # ------------------------------------------------------------------
    # Database-backed (async) API
    # ------------------------------------------------------------------

    async def add_key_async(
        self,
        session: AsyncSession,
        engine_type: str,
        api_key: str,
        source: KeySource = KeySource.USER,
        user_id: uuid.UUID | None = None,
        priority: int = 0,
    ) -> APIKeyConfig:
        """Encrypt *api_key*, persist it via the repository, and return its config."""
        config = self._build_config(
            engine_type,
            api_key,
            source,
            str(user_id) if user_id else None,
            priority,
        )
        repository = APIKeyRepository(session)
        await repository.create(
            user_id=user_id,
            engine_type=engine_type,
            encrypted_key=config.encrypted_key,
            key_hint=config.key_hint,
            key_source=source.value,
            status=config.status.value,
            priority=priority,
        )
        return config

    async def get_key_async(
        self,
        session: AsyncSession,
        engine_type: str,
        user_id: uuid.UUID | None = None,
    ) -> str | None:
        """Return a decrypted usable key from the database, or None.

        Mirrors ``get_key``. With *user_id*, a usable USER row wins over any
        usable SYSTEM/ENV row regardless of where it appears in the result.
        Otherwise the first usable SYSTEM/ENV row is returned.

        NOTE(review): whether ``get_by_user(user_id, ...)`` also returns
        SYSTEM/ENV rows is not visible here; if it does not, the fallback
        never fires for a user -- confirm against APIKeyRepository.
        """
        repository = APIKeyRepository(session)
        keys = await repository.get_by_user(user_id, engine_type)
        fallback = None
        for key in keys:
            source = KeySource(key.key_source)
            status = KeyStatus(key.status)
            if status not in self._USABLE_STATUSES:
                continue
            if user_id and source == KeySource.USER:
                # The user's own key outranks every fallback row.
                return self._decrypt(key.encrypted_key)
            if fallback is None and source in self._FALLBACK_SOURCES:
                if not user_id:
                    # Without a user nothing can outrank this row.
                    return self._decrypt(key.encrypted_key)
                # Remember it, but keep scanning for a USER row.
                fallback = key
        if fallback is not None:
            return self._decrypt(fallback.encrypted_key)
        return None

    async def get_any_available_key_async(
        self,
        session: AsyncSession,
        engine_type: str,
    ) -> str | None:
        """Return the first usable key from ``get_by_user(None, engine_type)``, or None."""
        repository = APIKeyRepository(session)
        keys = await repository.get_by_user(None, engine_type)
        for key in keys:
            status = KeyStatus(key.status)
            if status in self._USABLE_STATUSES:
                return self._decrypt(key.encrypted_key)
        return None

    async def remove_key_async(
        self,
        session: AsyncSession,
        user_id: uuid.UUID,
        engine_type: str,
        key_hint: str,
    ) -> bool:
        """Delete the first of the user's keys whose hint equals *key_hint*.

        Returns the repository's delete result, or False if no hint matched.
        """
        repository = APIKeyRepository(session)
        keys = await repository.get_by_user(user_id, engine_type)
        for key in keys:
            if key.key_hint == key_hint:
                return await repository.delete(key.id)
        return False

    async def list_keys_async(
        self,
        session: AsyncSession,
        user_id: uuid.UUID | None = None,
        engine_type: str | None = None,
    ) -> list[APIKeyConfig]:
        """List stored keys as configs (still encrypted).

        With *user_id*, only that user's keys are fetched; otherwise all keys,
        optionally filtered by *engine_type*. Timestamps are stringified.
        """
        repository = APIKeyRepository(session)
        if user_id:
            keys = await repository.get_by_user(user_id, engine_type)
        else:
            keys = await repository.list_all(engine_type=engine_type)
        return [
            APIKeyConfig(
                engine_type=k.engine_type,
                key_source=KeySource(k.key_source),
                encrypted_key=k.encrypted_key,
                key_hint=k.key_hint,
                status=KeyStatus(k.status),
                priority=k.priority,
                last_verified_at=str(k.last_verified_at) if k.last_verified_at else None,
                created_at=str(k.created_at) if k.created_at else None,
                user_id=str(k.user_id) if k.user_id else None,
            )
            for k in keys
        ]

    async def update_key_status_async(
        self,
        session: AsyncSession,
        key_id: uuid.UUID,
        status: KeyStatus,
    ) -> bool:
        """Persist a new *status* for *key_id*; True if the repository returned a row."""
        repository = APIKeyRepository(session)
        updated = await repository.update(key_id, status=status.value)
        return updated is not None