120 lines
3.4 KiB
Python
120 lines
3.4 KiB
Python
import uuid
|
|
from datetime import datetime, timezone
|
|
|
|
from sqlalchemy import select, and_, delete, update
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.api_key import APIKey
|
|
|
|
|
|
class APIKeyRepository:
    """Async data-access layer for ``APIKey`` rows.

    Every mutating method (``create``, ``update``, ``delete``) commits the
    wrapped session before returning, so callers do not commit themselves.
    """

    def __init__(self, session: AsyncSession):
        """Bind the repository to an async SQLAlchemy session.

        Args:
            session: Session used for every query and commit issued here.
        """
        self.session = session

    async def create(
        self,
        user_id: uuid.UUID,
        engine_type: str,
        encrypted_key: str,
        key_hint: str,
        key_source: str = "user",
        status: str = "active",
        priority: int = 0,
        last_verified_at: datetime | None = None,
    ) -> APIKey:
        """Insert a new API key row, commit, and return the refreshed instance.

        Args:
            user_id: Owner of the key.
            engine_type: Engine the key is used with.
            encrypted_key: Key material, stored exactly as given (no
                encryption happens here; presumably the caller encrypts).
            key_hint: Short displayable hint stored alongside the key.
            key_source: Origin label; defaults to ``"user"``.
            status: Initial status; defaults to ``"active"``.
            priority: Ordering weight; higher values sort first in this
                repository's lookup/list methods.
            last_verified_at: Optional time of the last successful check.

        Returns:
            The persisted ``APIKey``, refreshed from the database.
        """
        api_key = APIKey(
            user_id=user_id,
            engine_type=engine_type,
            encrypted_key=encrypted_key,
            key_hint=key_hint,
            key_source=key_source,
            status=status,
            priority=priority,
            last_verified_at=last_verified_at,
        )
        self.session.add(api_key)
        await self.session.commit()
        # Reload so any database-generated column values are visible on the
        # returned object.
        await self.session.refresh(api_key)
        return api_key

    async def get_by_id(self, key_id: uuid.UUID) -> APIKey | None:
        """Return the key whose primary key is ``key_id``, or ``None``."""
        result = await self.session.execute(
            select(APIKey).where(APIKey.id == key_id)
        )
        return result.scalar_one_or_none()

    async def get_by_user(
        self,
        user_id: uuid.UUID,
        engine_type: str | None = None,
    ) -> list[APIKey]:
        """Return all keys owned by ``user_id``, highest priority first.

        Args:
            user_id: Owner whose keys are listed.
            engine_type: If truthy, only keys for this engine are returned.

        Returns:
            Matching keys ordered by descending priority, ties broken by id
            so the order is deterministic.
        """
        conditions = [APIKey.user_id == user_id]
        if engine_type:
            conditions.append(APIKey.engine_type == engine_type)

        result = await self.session.execute(
            select(APIKey)
            .where(and_(*conditions))
            .order_by(APIKey.priority.desc(), APIKey.id)
        )
        return list(result.scalars().all())

    async def get_by_user_and_engine(
        self,
        user_id: uuid.UUID,
        engine_type: str,
    ) -> APIKey | None:
        """Return the highest-priority key of ``user_id`` for ``engine_type``.

        Args:
            user_id: Owner of the key.
            engine_type: Engine the key must belong to.

        Returns:
            The best matching key, or ``None`` if the user has none for this
            engine.
        """
        result = await self.session.execute(
            select(APIKey)
            .where(
                and_(
                    APIKey.user_id == user_id,
                    APIKey.engine_type == engine_type,
                )
            )
            # Nothing here enforces one key per (user, engine) -- get_by_user
            # returns a priority-ordered list for the same filter -- so pick
            # the top one instead of letting scalar_one_or_none() raise
            # MultipleResultsFound.
            .order_by(APIKey.priority.desc(), APIKey.id)
            .limit(1)
        )
        return result.scalars().first()

    async def update(
        self,
        key_id: uuid.UUID,
        **kwargs,
    ) -> APIKey | None:
        """Set the given attributes on key ``key_id``, commit, and return it.

        Args:
            key_id: Primary key of the row to modify.
            **kwargs: Attribute name/value pairs. Names that are not
                attributes of ``APIKey`` are silently ignored.

        Returns:
            The refreshed ``APIKey``, or ``None`` if no row has ``key_id``.
        """
        api_key = await self.get_by_id(key_id)
        if api_key is None:
            return None

        for key, value in kwargs.items():
            if hasattr(api_key, key):
                setattr(api_key, key, value)

        # Naive UTC timestamp -- the same value datetime.utcnow() produced,
        # without its Python 3.12+ DeprecationWarning.
        api_key.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
        await self.session.commit()
        await self.session.refresh(api_key)
        return api_key

    async def delete(self, key_id: uuid.UUID) -> bool:
        """Delete the key with ``key_id`` and commit.

        Returns:
            ``True`` if a row was deleted, ``False`` if none matched.
        """
        # In this body ``delete`` is the module-level SQLAlchemy construct,
        # not this method: class scope does not enclose function bodies.
        result = await self.session.execute(
            delete(APIKey).where(APIKey.id == key_id)
        )
        await self.session.commit()
        return result.rowcount > 0

    async def list_all(
        self,
        engine_type: str | None = None,
        status: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[APIKey]:
        """Return one page of keys across all users, highest priority first.

        Args:
            engine_type: If truthy, restrict to this engine.
            status: If truthy, restrict to this status.
            limit: Maximum number of rows returned.
            offset: Number of rows skipped before the page starts.

        Returns:
            Matching keys ordered by descending priority, then id.
        """
        conditions = []
        if engine_type:
            conditions.append(APIKey.engine_type == engine_type)
        if status:
            conditions.append(APIKey.status == status)

        query = select(APIKey)
        if conditions:
            query = query.where(and_(*conditions))
        # Priority is not unique; the id tie-breaker gives a total order so
        # LIMIT/OFFSET pages never skip or repeat rows with equal priority.
        query = (
            query.order_by(APIKey.priority.desc(), APIKey.id)
            .limit(limit)
            .offset(offset)
        )

        result = await self.session.execute(query)
        return list(result.scalars().all())
|