307 lines
10 KiB
Python
307 lines
10 KiB
Python
import base64
|
||
import json
|
||
import logging
|
||
import os
|
||
import uuid
|
||
from dataclasses import dataclass
|
||
from datetime import datetime
|
||
from enum import Enum
|
||
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.repositories.api_key_repository import APIKeyRepository
|
||
from app.services.key_encryption import KeyEncryption, get_key_encryption
|
||
from app.services.key_verifier import KeyStatus, KeyVerifierFactory
|
||
|
||
# Module-level logger named after this module (used by APIKeyManager.verify_key).
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class KeySource(str, Enum):
    """Origin of an API key.

    Mixing in ``str`` makes each member compare equal to its raw value, so a
    value read back from storage (e.g. ``"user"``) round-trips through
    ``KeySource(...)``.
    """

    SYSTEM = "system"  # platform-wide key; the default source for add_key
    USER = "user"  # key owned by a specific user (matched via user_id)
    ENV = "env"  # key loaded from an environment variable by load_env_keys
|
||
|
||
|
||
@dataclass
|
||
class KeyCredentials:
|
||
"""统一凭证格式,支持单Key和双Key"""
|
||
api_key: str
|
||
secret_key: str | None = None
|
||
|
||
def to_dict(self) -> dict:
|
||
if self.secret_key:
|
||
return {"api_key": self.api_key, "secret_key": self.secret_key}
|
||
return {"api_key": self.api_key}
|
||
|
||
def to_json(self) -> str:
|
||
return json.dumps(self.to_dict())
|
||
|
||
|
||
@dataclass
class APIKeyConfig:
    """Metadata for one stored API key; the secret itself is held only encrypted.

    Field order defines the positional constructor signature, so it must not
    be rearranged.
    """

    # Engine identifier the key belongs to (e.g. "chatgpt", "deepseek").
    engine_type: str
    # Where the key came from (system / user / env).
    key_source: KeySource
    # Ciphertext of the key (or of the JSON-encoded key pair); never plaintext.
    encrypted_key: str
    # Masked, display-safe hint, e.g. "sk-...xyz" or "abc...def|ghi...jkl".
    key_hint: str
    # Last known verification status; starts out UNKNOWN.
    status: KeyStatus = KeyStatus.UNKNOWN
    # Larger values sort first when several keys exist for one engine.
    priority: int = 0
    # Timestamp of the last verification, as a string (None if never verified).
    last_verified_at: str | None = None
    # Creation timestamp, as a string (None for purely in-memory keys).
    created_at: str | None = None
    # Owner id (stringified) for user-scoped keys; None otherwise.
    user_id: str | None = None
|
||
|
||
|
||
class APIKeyManager:
    """Manage per-engine API keys in memory and in the database.

    Synchronous methods work on an in-process registry (``self._keys``); the
    ``*_async`` methods go through ``APIKeyRepository`` on an ``AsyncSession``.
    Keys are always stored encrypted; only a masked ``key_hint`` is kept in
    clear text.
    """

    # NOTE(review): not referenced anywhere in this class -- encryption goes
    # through get_key_encryption(). Kept for backward compatibility only; the
    # hard-coded fallback value must never be relied on in production.
    _ENCRYPTION_KEY = os.getenv("API_KEY_ENCRYPTION_KEY", "geo-platform-default-key-change-in-production")

    # Statuses that may be handed out to callers. UNKNOWN is included so that
    # freshly added, not-yet-verified keys are usable immediately.
    _USABLE_STATUSES = frozenset({KeyStatus.ACTIVE, KeyStatus.UNKNOWN})

    # Sources any caller may fall back to when they have no USER key of their own.
    _FALLBACK_SOURCES = frozenset({KeySource.SYSTEM, KeySource.ENV})

    # Engine identifier -> environment variable read by load_env_keys().
    _ENV_MAPPING = {
        "chatgpt": "OPENAI_API_KEY",
        "perplexity": "PERPLEXITY_API_KEY",
        "kimi": "MOONSHOT_API_KEY",
        "wenxin": "BAIDU_QIANFAN_API_KEY",
        "doubao": "DOUBAO_API_KEY",
        "deepseek": "DEEPSEEK_API_KEY",
        "qwen": "DASHSCOPE_API_KEY",
        "gemini": "GOOGLE_API_KEY",
        "yuanbao": "HUNYUAN_API_KEY",
    }

    def __init__(self):
        # engine_type -> configs, kept sorted by descending priority (see add_key).
        self._keys: dict[str, list[APIKeyConfig]] = {}

    # ------------------------------------------------------------------ #
    # Credential encoding helpers
    # ------------------------------------------------------------------ #

    def _encode_credentials(self, credentials: str | dict) -> tuple[str, str]:
        """Encrypt *credentials* and build a display hint for them.

        A dict (dual-key credentials) is serialised to JSON before encryption
        and gets a two-part hint; a plain string is encrypted as-is.

        Returns:
            ``(encrypted_key, key_hint)``.
        """
        if isinstance(credentials, dict):
            key_hint = self._create_dual_key_hint(credentials)
            plaintext = json.dumps(credentials)
        else:
            key_hint = self._mask_key(credentials)
            plaintext = credentials
        return self._encrypt(plaintext), key_hint

    @staticmethod
    def _decode_credentials(decrypted: str) -> str | dict:
        """Turn decrypted plaintext back into the value that was stored.

        Only a JSON *object* is treated as dual-key credentials. Anything else
        is returned verbatim, so a single key that merely happens to be valid
        JSON (e.g. ``"12345678901"``, ``"true"``, ``"null"``) stays a string
        instead of turning into an int/bool/None.
        """
        try:
            parsed = json.loads(decrypted)
        except (json.JSONDecodeError, TypeError):
            return decrypted
        return parsed if isinstance(parsed, dict) else decrypted

    # ------------------------------------------------------------------ #
    # In-memory registry
    # ------------------------------------------------------------------ #

    def add_key(
        self,
        engine_type: str,
        credentials: str | dict,
        source: KeySource = KeySource.SYSTEM,
        user_id: str | None = None,
        priority: int = 0,
    ) -> APIKeyConfig:
        """Register a key (or dual-key dict) for *engine_type* in memory.

        The key is stored encrypted with status UNKNOWN. The engine's list is
        re-sorted so higher ``priority`` comes first; the sort is stable, so
        equal priorities keep insertion order.

        Returns:
            The newly created :class:`APIKeyConfig`.
        """
        encrypted_key, key_hint = self._encode_credentials(credentials)
        config = APIKeyConfig(
            engine_type=engine_type,
            key_source=source,
            encrypted_key=encrypted_key,
            key_hint=key_hint,
            status=KeyStatus.UNKNOWN,
            priority=priority,
            user_id=user_id,
        )
        configs = self._keys.setdefault(engine_type, [])
        configs.append(config)
        configs.sort(key=lambda k: k.priority, reverse=True)
        return config

    def _create_dual_key_hint(self, credentials: dict) -> str:
        """Build an ``"<api hint>|<secret hint>"`` string for a key pair.

        A missing or empty half is shown as ``"***"``.
        """
        api_key = credentials.get("api_key", "")
        secret_key = credentials.get("secret_key", "")
        api_hint = self._mask_key(api_key) if api_key else "***"
        secret_hint = self._mask_key(secret_key) if secret_key else "***"
        return f"{api_hint}|{secret_hint}"

    def get_key(self, engine_type: str, user_id: str | None = None) -> str | dict | None:
        """Return the best usable key for *engine_type*, decrypted.

        Returns:
            A dict for dual-key credentials, a str for a single key, or None
            when no usable key exists (see :meth:`_find_best_key`).
        """
        config = self._find_best_key(self._keys.get(engine_type, []), user_id)
        if config is None:
            return None
        return self._decode_credentials(self._decrypt(config.encrypted_key))

    def get_credentials(self, engine_type: str, user_id: str | None = None) -> KeyCredentials | None:
        """Like :meth:`get_key`, but normalised into :class:`KeyCredentials`.

        Returns None when there is no usable (or an empty) key.
        """
        key_data = self.get_key(engine_type, user_id)
        if not key_data:
            return None
        if isinstance(key_data, dict):
            return KeyCredentials(
                api_key=key_data.get("api_key", ""),
                secret_key=key_data.get("secret_key"),
            )
        return KeyCredentials(api_key=key_data)

    def get_all_keys(self, engine_type: str, user_id: str | None = None) -> dict | str | None:
        """Alias of :meth:`get_key`."""
        return self.get_key(engine_type, user_id)

    def _find_best_key(self, configs: list[APIKeyConfig], user_id: str | None = None) -> APIKeyConfig | None:
        """Pick the first usable config, preferring the caller's own key.

        With *user_id*, a usable USER key owned by that user wins. Otherwise
        (or if none exists) the first usable SYSTEM/ENV key is returned.
        Because *configs* is priority-sorted, "first" means highest priority.
        """
        if user_id:
            for c in configs:
                if (
                    c.key_source == KeySource.USER
                    and c.user_id == user_id
                    and c.status in self._USABLE_STATUSES
                ):
                    return c
        for c in configs:
            if c.key_source in self._FALLBACK_SOURCES and c.status in self._USABLE_STATUSES:
                return c
        return None

    def get_any_available_key(self, engine_type: str) -> str | None:
        """Return the first usable key for *engine_type*, regardless of source or owner.

        NOTE: the decrypted plaintext is returned as-is, so for a dual-key
        entry this is its JSON text rather than a dict.
        """
        for c in self._keys.get(engine_type, []):
            if c.status in self._USABLE_STATUSES:
                return self._decrypt(c.encrypted_key)
        return None

    def remove_key(self, engine_type: str, key_hint: str) -> bool:
        """Remove the first config for *engine_type* whose hint equals *key_hint*.

        NOTE(review): hints are not unique (every key of 8 chars or fewer is
        masked to ``"***"``), so this may remove a different key than intended.

        Returns:
            True if a config was removed.
        """
        configs = self._keys.get(engine_type, [])
        for i, c in enumerate(configs):
            if c.key_hint == key_hint:
                # Safe to mutate: we return immediately after popping.
                configs.pop(i)
                return True
        return False

    def list_keys(self, engine_type: str | None = None) -> list[APIKeyConfig]:
        """Return registered configs for one engine, or for all engines.

        A fresh list is returned, so callers cannot accidentally reorder or
        drop entries from the internal registry.
        """
        if engine_type:
            return list(self._keys.get(engine_type, []))
        return [c for configs in self._keys.values() for c in configs]

    async def verify_key(self, engine_type: str, api_key: str) -> KeyStatus:
        """Check *api_key* against the engine via ``KeyVerifierFactory``.

        Keys that are empty or shorter than 10 characters are rejected as
        INVALID without a remote call. Verifier errors are logged and reported
        as UNKNOWN rather than raised.
        """
        if not api_key or len(api_key) < 10:
            return KeyStatus.INVALID
        try:
            return await KeyVerifierFactory.verify(engine_type, api_key)
        except Exception as e:
            logger.warning("[api_key_manager] Key verification failed: %s", e)
            return KeyStatus.UNKNOWN

    def _encrypt(self, plaintext: str) -> str:
        """Encrypt with the shared key-encryption service."""
        return get_key_encryption().encrypt(plaintext)

    def _decrypt(self, ciphertext: str) -> str:
        """Decrypt with the shared key-encryption service."""
        return get_key_encryption().decrypt(ciphertext)

    def _mask_key(self, key: str) -> str:
        """Return a display-safe hint: first and last 3 chars, or ``"***"`` if 8 chars or fewer."""
        if len(key) <= 8:
            return "***"
        return key[:3] + "..." + key[-3:]

    def load_env_keys(self):
        """Register every non-empty key found via ``_ENV_MAPPING`` as an ENV key.

        NOTE: each call appends again; calling this repeatedly registers
        duplicate entries.
        """
        for engine, env_var in self._ENV_MAPPING.items():
            key = os.getenv(env_var, "")
            if key:
                self.add_key(engine, key, source=KeySource.ENV, priority=0)

    # ------------------------------------------------------------------ #
    # Database-backed operations
    # ------------------------------------------------------------------ #

    async def add_key_async(
        self,
        session: AsyncSession,
        engine_type: str,
        api_key: str | dict,
        source: KeySource = KeySource.USER,
        user_id: uuid.UUID | None = None,
        priority: int = 0,
    ) -> APIKeyConfig:
        """Encrypt and persist a key through ``APIKeyRepository``.

        *api_key* may be a single key or a dual-key dict, exactly as in
        :meth:`add_key`. The record is created with status UNKNOWN.

        Returns:
            An :class:`APIKeyConfig` mirroring what was written.
        """
        encrypted_key, key_hint = self._encode_credentials(api_key)
        config = APIKeyConfig(
            engine_type=engine_type,
            key_source=source,
            encrypted_key=encrypted_key,
            key_hint=key_hint,
            status=KeyStatus.UNKNOWN,
            priority=priority,
            user_id=str(user_id) if user_id else None,
        )
        repository = APIKeyRepository(session)
        await repository.create(
            user_id=user_id,
            engine_type=engine_type,
            encrypted_key=config.encrypted_key,
            key_hint=config.key_hint,
            key_source=source.value,
            status=config.status.value,
            priority=priority,
        )
        return config

    async def get_key_async(
        self,
        session: AsyncSession,
        engine_type: str,
        user_id: uuid.UUID | None = None,
    ) -> str | None:
        """Return a usable decrypted key from the database, preferring the user's own.

        Like :meth:`_find_best_key`, all USER keys are checked first (when
        *user_id* is given) before any SYSTEM/ENV fallback. The answer
        therefore does not depend on the order the repository returns rows in.

        NOTE(review): assumes ``get_by_user(user_id, ...)`` also returns the
        fallback keys visible to that user -- confirm against APIKeyRepository.
        The decrypted plaintext is returned as-is (JSON text for dual keys).
        """
        repository = APIKeyRepository(session)
        # Materialise: we may iterate twice.
        keys = list(await repository.get_by_user(user_id, engine_type))

        if user_id:
            for key in keys:
                if (
                    KeySource(key.key_source) == KeySource.USER
                    and KeyStatus(key.status) in self._USABLE_STATUSES
                ):
                    return self._decrypt(key.encrypted_key)

        for key in keys:
            if (
                KeySource(key.key_source) in self._FALLBACK_SOURCES
                and KeyStatus(key.status) in self._USABLE_STATUSES
            ):
                return self._decrypt(key.encrypted_key)
        return None

    async def get_any_available_key_async(
        self,
        session: AsyncSession,
        engine_type: str,
    ) -> str | None:
        """Return the first usable decrypted key from ``get_by_user(None, engine_type)``.

        NOTE(review): presumably this yields the non-user-scoped keys -- verify
        against APIKeyRepository.get_by_user.
        """
        repository = APIKeyRepository(session)
        keys = await repository.get_by_user(None, engine_type)
        for key in keys:
            if KeyStatus(key.status) in self._USABLE_STATUSES:
                return self._decrypt(key.encrypted_key)
        return None

    async def remove_key_async(
        self,
        session: AsyncSession,
        user_id: uuid.UUID,
        engine_type: str,
        key_hint: str,
    ) -> bool:
        """Delete the first of the user's keys for *engine_type* whose hint matches.

        The same hint-collision caveat as :meth:`remove_key` applies.

        Returns:
            The repository's delete result, or False if no hint matched.
        """
        repository = APIKeyRepository(session)
        keys = await repository.get_by_user(user_id, engine_type)
        for key in keys:
            if key.key_hint == key_hint:
                return await repository.delete(key.id)
        return False

    async def list_keys_async(
        self,
        session: AsyncSession,
        user_id: uuid.UUID | None = None,
        engine_type: str | None = None,
    ) -> list[APIKeyConfig]:
        """List stored keys (for one user, or all) as :class:`APIKeyConfig` objects.

        Timestamps and ``user_id`` are stringified; missing values become None.
        """
        repository = APIKeyRepository(session)
        if user_id:
            keys = await repository.get_by_user(user_id, engine_type)
        else:
            keys = await repository.list_all(engine_type=engine_type)
        return [
            APIKeyConfig(
                engine_type=k.engine_type,
                key_source=KeySource(k.key_source),
                encrypted_key=k.encrypted_key,
                key_hint=k.key_hint,
                status=KeyStatus(k.status),
                priority=k.priority,
                last_verified_at=str(k.last_verified_at) if k.last_verified_at else None,
                created_at=str(k.created_at) if k.created_at else None,
                user_id=str(k.user_id) if k.user_id else None,
            )
            for k in keys
        ]

    async def update_key_status_async(
        self,
        session: AsyncSession,
        key_id: uuid.UUID,
        status: KeyStatus,
    ) -> bool:
        """Persist a new *status* for the key with *key_id*.

        Returns:
            True if the repository reported an updated record (non-None result).
        """
        repository = APIKeyRepository(session)
        updated = await repository.update(key_id, status=status.value)
        return updated is not None