geo/backend/app/services/api_key_manager.py

307 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import base64
import json
import logging
import os
import uuid
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from sqlalchemy.ext.asyncio import AsyncSession
from app.repositories.api_key_repository import APIKeyRepository
from app.services.key_encryption import KeyEncryption, get_key_encryption
from app.services.key_verifier import KeyStatus, KeyVerifierFactory
# Module-level logger; APIKeyManager.verify_key reports verifier failures through it.
logger = logging.getLogger(__name__)
class KeySource(str, Enum):
    """Origin of a stored API key.

    Subclasses ``str`` so members compare equal to their raw values, which are
    what gets persisted (``source.value``) and parsed back (``KeySource(...)``).
    """

    SYSTEM = "system"  # shared key; usable as a fallback for any user
    USER = "user"  # key registered for a specific user (matched via user_id)
    ENV = "env"  # key read from an environment variable by load_env_keys
@dataclass
class KeyCredentials:
"""统一凭证格式支持单Key和双Key"""
api_key: str
secret_key: str | None = None
def to_dict(self) -> dict:
if self.secret_key:
return {"api_key": self.api_key, "secret_key": self.secret_key}
return {"api_key": self.api_key}
def to_json(self) -> str:
return json.dumps(self.to_dict())
@dataclass
class APIKeyConfig:
    """Metadata record for one stored API key.

    The secret itself is only held in encrypted form; ``key_hint`` is a masked
    preview of it (``APIKeyManager.remove_key`` looks keys up by this hint).
    """

    engine_type: str  # engine identifier such as "chatgpt" or "kimi"
    key_source: KeySource  # where the key came from (system / user / env)
    encrypted_key: str  # ciphertext of the key (or of the JSON for dual-key credentials)
    key_hint: str  # masked key, e.g. "abc...xyz" or "***"; dual keys use "api_hint|secret_hint"
    status: KeyStatus = KeyStatus.UNKNOWN  # verification status; new keys start as UNKNOWN
    priority: int = 0  # higher values are tried first when selecting a key
    last_verified_at: str | None = None  # stringified timestamp, filled when loaded from the DB
    created_at: str | None = None  # stringified timestamp, filled when loaded from the DB
    user_id: str | None = None  # owning user id as a string; None when not tied to a user
class APIKeyManager:
    """Store, select and verify API keys for the supported AI engines.

    Two storage paths exist side by side:

    * an in-memory registry (``add_key``, ``get_key``, ``list_keys``, ...)
      mapping each engine type to its configs, kept sorted by descending
      priority;
    * a database-backed path (the ``*_async`` methods) that reads and writes
      through ``APIKeyRepository`` using the caller's ``AsyncSession``.

    Secrets are always stored encrypted; only a masked ``key_hint`` is kept in
    clear text. When resolving a key for a user, that user's own USER keys are
    preferred and SYSTEM/ENV keys are used as a fallback.
    """

    # NOTE(review): not referenced by any method of this class -- encryption goes
    # through get_key_encryption(). Kept only for backward compatibility; the
    # hard-coded default must never be relied on in production.
    _ENCRYPTION_KEY = os.getenv("API_KEY_ENCRYPTION_KEY", "geo-platform-default-key-change-in-production")
    # Statuses under which a key may be handed out (unverified keys are allowed).
    _USABLE_STATUSES = frozenset({KeyStatus.ACTIVE, KeyStatus.UNKNOWN})
    # Sources that may serve any user who has no usable key of their own.
    _FALLBACK_SOURCES = frozenset({KeySource.SYSTEM, KeySource.ENV})
    # Engine type -> environment variable consulted by load_env_keys().
    _ENV_MAPPING = {
        "chatgpt": "OPENAI_API_KEY",
        "perplexity": "PERPLEXITY_API_KEY",
        "kimi": "MOONSHOT_API_KEY",
        "wenxin": "BAIDU_QIANFAN_API_KEY",
        "doubao": "DOUBAO_API_KEY",
        "deepseek": "DEEPSEEK_API_KEY",
        "qwen": "DASHSCOPE_API_KEY",
        "gemini": "GOOGLE_API_KEY",
        "yuanbao": "HUNYUAN_API_KEY",
    }

    def __init__(self):
        # engine_type -> configs, sorted by descending priority (see add_key).
        self._keys: dict[str, list[APIKeyConfig]] = {}

    # ------------------------------------------------------------------
    # In-memory registry
    # ------------------------------------------------------------------

    def _encrypt_credentials(self, credentials: str | dict) -> tuple[str, str]:
        """Encrypt *credentials* and build the matching masked hint.

        Dict credentials (dual-key engines) are JSON-encoded before encryption
        and get a combined ``"api|secret"`` hint; plain strings are encrypted
        as-is and masked with ``_mask_key``.

        Returns:
            ``(encrypted_key, key_hint)``.
        """
        if isinstance(credentials, dict):
            key_hint = self._create_dual_key_hint(credentials)
            encrypted_key = self._encrypt(json.dumps(credentials))
        else:
            key_hint = self._mask_key(credentials)
            encrypted_key = self._encrypt(credentials)
        return encrypted_key, key_hint

    def add_key(
        self,
        engine_type: str,
        credentials: str | dict,
        source: KeySource = KeySource.SYSTEM,
        user_id: str | None = None,
        priority: int = 0,
    ) -> APIKeyConfig:
        """Register a key in the in-memory registry.

        Args:
            engine_type: Engine the key belongs to.
            credentials: A plain API key string, or a dict (e.g. with
                ``api_key``/``secret_key``) for dual-key engines.
            source: Origin of the key.
            user_id: Owning user, for USER keys.
            priority: Higher values are preferred during key selection.

        Returns:
            The stored ``APIKeyConfig`` (status ``UNKNOWN``).
        """
        encrypted_key, key_hint = self._encrypt_credentials(credentials)
        config = APIKeyConfig(
            engine_type=engine_type,
            key_source=source,
            encrypted_key=encrypted_key,
            key_hint=key_hint,
            status=KeyStatus.UNKNOWN,
            priority=priority,
            user_id=user_id,
        )
        configs = self._keys.setdefault(engine_type, [])
        configs.append(config)
        # list.sort is stable (also with reverse=True): among equal priorities
        # the earlier-added key stays first.
        configs.sort(key=lambda k: k.priority, reverse=True)
        return config

    def _create_dual_key_hint(self, credentials: dict) -> str:
        """Build a ``"<api hint>|<secret hint>"`` string for dual-key credentials."""
        api_key = credentials.get("api_key", "")
        secret_key = credentials.get("secret_key", "")
        api_hint = self._mask_key(api_key) if api_key else "***"
        secret_hint = self._mask_key(secret_key) if secret_key else "***"
        return f"{api_hint}|{secret_hint}"

    def get_key(self, engine_type: str, user_id: str | None = None) -> str | dict | None:
        """Return the best usable decrypted key for *engine_type*.

        Returns:
            A dict for credentials stored as a JSON object (dual-key engines),
            the plain key string otherwise, or ``None`` if no key is usable.
        """
        configs = self._keys.get(engine_type, [])
        config = self._find_best_key(configs, user_id)
        if not config:
            return None
        decrypted = self._decrypt(config.encrypted_key)
        # Dual-key credentials were stored as a JSON object. Only a decoded
        # dict counts as such: a plain key that merely happens to be valid
        # JSON (e.g. "1234567890" or "null") must come back unchanged rather
        # than as an int/None.
        try:
            parsed = json.loads(decrypted)
        except (json.JSONDecodeError, TypeError):
            return decrypted
        return parsed if isinstance(parsed, dict) else decrypted

    def get_credentials(self, engine_type: str, user_id: str | None = None) -> KeyCredentials | None:
        """Like ``get_key`` but normalized into a ``KeyCredentials`` object."""
        key_data = self.get_key(engine_type, user_id)
        if not key_data:
            return None
        if isinstance(key_data, dict):
            return KeyCredentials(
                api_key=key_data.get("api_key", ""),
                secret_key=key_data.get("secret_key"),
            )
        return KeyCredentials(api_key=key_data)

    def get_all_keys(self, engine_type: str, user_id: str | None = None) -> dict | str | None:
        """Alias of ``get_key``."""
        return self.get_key(engine_type, user_id)

    def _find_best_key(self, configs: list[APIKeyConfig], user_id: str | None = None) -> APIKeyConfig | None:
        """Pick the config to use from *configs* (assumed priority-sorted).

        The user's own usable USER key wins; otherwise the first usable
        SYSTEM/ENV key is returned; ``None`` if neither exists.
        """
        if user_id:
            for c in configs:
                if (
                    c.key_source == KeySource.USER
                    and c.user_id == user_id
                    and c.status in self._USABLE_STATUSES
                ):
                    return c
        for c in configs:
            if c.key_source in self._FALLBACK_SOURCES and c.status in self._USABLE_STATUSES:
                return c
        return None

    def get_any_available_key(self, engine_type: str) -> str | None:
        """Return the decrypted text of the first usable key, whatever its source.

        NOTE(review): this ignores key ownership (it may return another user's
        USER key) and does not JSON-decode dual-key credentials.
        """
        for c in self._keys.get(engine_type, []):
            if c.status in self._USABLE_STATUSES:
                return self._decrypt(c.encrypted_key)
        return None

    def remove_key(self, engine_type: str, key_hint: str) -> bool:
        """Remove the first key of *engine_type* whose hint equals *key_hint*.

        NOTE(review): hints are masked and not unique per user, so this may
        remove a different owner's key with a colliding hint.

        Returns:
            True if a key was removed.
        """
        configs = self._keys.get(engine_type, [])
        for i, c in enumerate(configs):
            if c.key_hint == key_hint:
                configs.pop(i)
                return True
        return False

    def list_keys(self, engine_type: str | None = None) -> list[APIKeyConfig]:
        """List registered configs for *engine_type*, or for all engines.

        A new list is returned so callers cannot reorder or truncate the
        internal, priority-sorted registry.
        """
        if engine_type:
            return list(self._keys.get(engine_type, []))
        return [config for configs in self._keys.values() for config in configs]

    async def verify_key(self, engine_type: str, api_key: str) -> KeyStatus:
        """Check *api_key* against the engine's verifier.

        Empty keys and keys shorter than 10 characters are rejected as
        ``INVALID`` without calling the verifier. If the verifier raises, the
        failure is logged and ``UNKNOWN`` is returned.
        """
        if not api_key or len(api_key) < 10:
            return KeyStatus.INVALID
        try:
            return await KeyVerifierFactory.verify(engine_type, api_key)
        except Exception as e:  # verifier boundary: degrade to UNKNOWN, don't crash
            logger.warning("[api_key_manager] Key verification failed: %s", e)
            return KeyStatus.UNKNOWN

    def _encrypt(self, plaintext: str) -> str:
        """Encrypt *plaintext* with the shared key-encryption helper."""
        return get_key_encryption().encrypt(plaintext)

    def _decrypt(self, ciphertext: str) -> str:
        """Decrypt *ciphertext* with the shared key-encryption helper."""
        return get_key_encryption().decrypt(ciphertext)

    def _mask_key(self, key: str) -> str:
        """Return ``"***"`` for keys of 8 chars or fewer, else ``first3...last3``."""
        if len(key) <= 8:
            return "***"
        return key[:3] + "..." + key[-3:]

    def load_env_keys(self):
        """Register an ENV key for every engine whose variable in _ENV_MAPPING is set.

        Each call appends new entries, so calling it twice registers duplicates.
        """
        for engine, env_var in self._ENV_MAPPING.items():
            key = os.getenv(env_var, "")
            if key:
                self.add_key(engine, key, source=KeySource.ENV, priority=0)

    # ------------------------------------------------------------------
    # Database-backed registry
    # ------------------------------------------------------------------

    async def add_key_async(
        self,
        session: AsyncSession,
        engine_type: str,
        api_key: str,
        source: KeySource = KeySource.USER,
        user_id: uuid.UUID | None = None,
        priority: int = 0,
    ) -> APIKeyConfig:
        """Encrypt *api_key* and persist it through ``APIKeyRepository``.

        Returns:
            An ``APIKeyConfig`` describing the stored row (status ``UNKNOWN``).
        """
        encrypted_key, key_hint = self._encrypt_credentials(api_key)
        config = APIKeyConfig(
            engine_type=engine_type,
            key_source=source,
            encrypted_key=encrypted_key,
            key_hint=key_hint,
            status=KeyStatus.UNKNOWN,
            priority=priority,
            user_id=str(user_id) if user_id else None,
        )
        repository = APIKeyRepository(session)
        await repository.create(
            user_id=user_id,
            engine_type=engine_type,
            encrypted_key=config.encrypted_key,
            key_hint=config.key_hint,
            key_source=source.value,
            status=config.status.value,
            priority=priority,
        )
        return config

    async def get_key_async(
        self,
        session: AsyncSession,
        engine_type: str,
        user_id: uuid.UUID | None = None,
    ) -> str | None:
        """Resolve the decrypted key to use for *engine_type* from the database.

        Mirrors ``_find_best_key``: when *user_id* is given, a usable USER key
        wins over any SYSTEM/ENV key regardless of the order in which the
        repository returns rows; otherwise (or when the user has none) the
        first usable SYSTEM/ENV key is used.

        Returns:
            The decrypted key string, or ``None`` when no usable key exists.
        """
        repository = APIKeyRepository(session)
        keys = await repository.get_by_user(user_id, engine_type)
        # Keep repository order so that, within each pass, earlier rows win.
        candidates = [
            (KeySource(key.key_source), key)
            for key in keys
            if KeyStatus(key.status) in self._USABLE_STATUSES
        ]
        if user_id:
            # NOTE(review): assumes get_by_user only returns USER rows owned by
            # user_id (plus shared rows) -- TODO confirm in APIKeyRepository.
            for source, key in candidates:
                if source == KeySource.USER:
                    return self._decrypt(key.encrypted_key)
        for source, key in candidates:
            if source in self._FALLBACK_SOURCES:
                return self._decrypt(key.encrypted_key)
        return None

    async def get_any_available_key_async(
        self,
        session: AsyncSession,
        engine_type: str,
    ) -> str | None:
        """Return the first usable decrypted key from ``get_by_user(None, engine_type)``."""
        repository = APIKeyRepository(session)
        keys = await repository.get_by_user(None, engine_type)
        for key in keys:
            status = KeyStatus(key.status)
            if status in self._USABLE_STATUSES:
                return self._decrypt(key.encrypted_key)
        return None

    async def remove_key_async(
        self,
        session: AsyncSession,
        user_id: uuid.UUID,
        engine_type: str,
        key_hint: str,
    ) -> bool:
        """Delete the first of the user's *engine_type* keys matching *key_hint*.

        Returns:
            The repository's delete result, or False if no hint matched.
        """
        repository = APIKeyRepository(session)
        keys = await repository.get_by_user(user_id, engine_type)
        for key in keys:
            if key.key_hint == key_hint:
                return await repository.delete(key.id)
        return False

    async def list_keys_async(
        self,
        session: AsyncSession,
        user_id: uuid.UUID | None = None,
        engine_type: str | None = None,
    ) -> list[APIKeyConfig]:
        """List stored keys as ``APIKeyConfig`` records.

        With *user_id*, rows come from ``get_by_user``; otherwise from
        ``list_all``. Timestamps and the user id are converted to strings.
        """
        repository = APIKeyRepository(session)
        if user_id:
            keys = await repository.get_by_user(user_id, engine_type)
        else:
            keys = await repository.list_all(engine_type=engine_type)
        return [
            APIKeyConfig(
                engine_type=k.engine_type,
                key_source=KeySource(k.key_source),
                encrypted_key=k.encrypted_key,
                key_hint=k.key_hint,
                status=KeyStatus(k.status),
                priority=k.priority,
                last_verified_at=str(k.last_verified_at) if k.last_verified_at else None,
                created_at=str(k.created_at) if k.created_at else None,
                user_id=str(k.user_id) if k.user_id else None,
            )
            for k in keys
        ]

    async def update_key_status_async(
        self,
        session: AsyncSession,
        key_id: uuid.UUID,
        status: KeyStatus,
    ) -> bool:
        """Persist a new *status* for the key row *key_id*.

        Returns:
            True if the repository reported an updated row (non-None result).
        """
        repository = APIKeyRepository(session)
        updated = await repository.update(key_id, status=status.value)
        return updated is not None