# geo/backend/app/repositories/api_key_repository.py
#
# (Source-listing metadata from the original paste: 120 lines, 3.4 KiB, Python.)

import uuid
from datetime import datetime, timezone

from sqlalchemy import select, and_, delete, update
from sqlalchemy import inspect as sa_inspect
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.api_key import APIKey


class APIKeyRepository:
    """Data-access helper for ``APIKey`` rows.

    Every mutating method (``create``, ``update``, ``delete``) commits the
    session it was given, so each call is its own unit of work from the
    caller's point of view.
    """

    def __init__(self, session: AsyncSession):
        """Bind the repository to an async SQLAlchemy session.

        Args:
            session: Session used for every query and commit issued by this
                repository instance.
        """
        self.session = session

    async def create(
        self,
        user_id: uuid.UUID,
        engine_type: str,
        encrypted_key: str,
        key_hint: str,
        key_source: str = "user",
        status: str = "active",
        priority: int = 0,
        last_verified_at: datetime | None = None,
    ) -> APIKey:
        """Insert a new API key row, commit, and return the refreshed instance.

        Args:
            user_id: Owner of the key.
            engine_type: Engine the key is for.
            encrypted_key: Key material as it should be stored (the caller is
                responsible for encrypting it; this method stores it as given).
            key_hint: Short displayable hint for the key.
            key_source: Origin of the key; defaults to ``"user"``.
            status: Initial status; defaults to ``"active"``.
            priority: Ordering weight; higher values sort first in the
                list/lookup methods of this repository.
            last_verified_at: Optional timestamp of the last verification.

        Returns:
            The persisted ``APIKey``, refreshed after commit so its attributes
            reflect the stored row.
        """
        api_key = APIKey(
            user_id=user_id,
            engine_type=engine_type,
            encrypted_key=encrypted_key,
            key_hint=key_hint,
            key_source=key_source,
            status=status,
            priority=priority,
            last_verified_at=last_verified_at,
        )
        self.session.add(api_key)
        await self.session.commit()
        await self.session.refresh(api_key)
        return api_key

    async def get_by_id(self, key_id: uuid.UUID) -> APIKey | None:
        """Return the key with primary key ``key_id``, or ``None`` if absent."""
        result = await self.session.execute(
            select(APIKey).where(APIKey.id == key_id)
        )
        return result.scalar_one_or_none()

    async def get_by_user(
        self,
        user_id: uuid.UUID,
        engine_type: str | None = None,
    ) -> list[APIKey]:
        """List a user's keys, highest priority first.

        Args:
            user_id: Owner whose keys are returned.
            engine_type: When truthy, only keys for this engine are returned;
                ``None`` or ``""`` returns keys for every engine.

        Returns:
            Keys ordered by ``priority`` descending, then by ``id`` so that
            equal-priority keys come back in a stable order.
        """
        conditions = [APIKey.user_id == user_id]
        if engine_type:
            conditions.append(APIKey.engine_type == engine_type)
        result = await self.session.execute(
            select(APIKey)
            .where(and_(*conditions))
            .order_by(APIKey.priority.desc(), APIKey.id)
        )
        return list(result.scalars().all())

    async def get_by_user_and_engine(
        self,
        user_id: uuid.UUID,
        engine_type: str,
    ) -> APIKey | None:
        """Return the user's highest-priority key for one engine, or ``None``.

        A user may hold several keys for the same engine (``get_by_user``
        filters by engine and returns a priority-ordered list). Using
        ``scalar_one_or_none()`` here would raise ``MultipleResultsFound`` as
        soon as a second key exists, so the top-priority row is picked instead.
        If only one row can match, the result is the same as before.
        """
        result = await self.session.execute(
            select(APIKey)
            .where(
                and_(
                    APIKey.user_id == user_id,
                    APIKey.engine_type == engine_type,
                )
            )
            .order_by(APIKey.priority.desc(), APIKey.id)
            .limit(1)
        )
        return result.scalars().first()

    async def update(
        self,
        key_id: uuid.UUID,
        **kwargs,
    ) -> APIKey | None:
        """Apply field updates to an existing key, commit, and return it.

        Only mapped column attributes other than the primary key ``id`` are
        written. Any other name is silently skipped, as unknown names were
        before: relationships, methods, and SQLAlchemy internals such as
        ``_sa_instance_state`` or ``metadata``. Previously a ``hasattr`` check
        let kwargs clobber those.

        ``updated_at`` is always set to the current UTC time (naive datetime),
        overriding any value passed in ``kwargs``.

        NOTE(review): ``user_id`` remains writable; if ``kwargs`` can come from
        a request body, callers should strip it to prevent ownership changes.

        Args:
            key_id: Primary key of the row to update.
            **kwargs: Column name -> new value.

        Returns:
            The refreshed ``APIKey``, or ``None`` if no row has ``key_id``.
        """
        api_key = await self.get_by_id(key_id)
        if api_key is None:
            return None

        writable_fields = set(sa_inspect(APIKey).column_attrs.keys()) - {"id"}
        for field, value in kwargs.items():
            if field in writable_fields:
                setattr(api_key, field, value)

        # Same naive-UTC value datetime.utcnow() produced, without the
        # Python 3.12+ deprecation warning.
        api_key.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
        await self.session.commit()
        await self.session.refresh(api_key)
        return api_key

    async def delete(self, key_id: uuid.UUID) -> bool:
        """Delete the key with ``key_id`` and commit.

        Returns:
            ``True`` if a row was deleted, ``False`` if none matched.
        """
        result = await self.session.execute(
            delete(APIKey).where(APIKey.id == key_id)
        )
        await self.session.commit()
        return result.rowcount > 0

    async def list_all(
        self,
        engine_type: str | None = None,
        status: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[APIKey]:
        """Return one page of keys across all users, highest priority first.

        Args:
            engine_type: When truthy, restrict to this engine.
            status: When truthy, restrict to this status.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip before the page starts.

        Returns:
            Keys ordered by ``priority`` descending, then ``id``. The ``id``
            tie-breaker matters for pagination: without it, rows sharing a
            priority (e.g. the default 0) have no defined order, and
            consecutive ``offset`` pages can skip or repeat rows.
        """
        conditions = []
        if engine_type:
            conditions.append(APIKey.engine_type == engine_type)
        if status:
            conditions.append(APIKey.status == status)
        query = select(APIKey)
        if conditions:
            query = query.where(and_(*conditions))
        query = (
            query.order_by(APIKey.priority.desc(), APIKey.id)
            .limit(limit)
            .offset(offset)
        )
        result = await self.session.execute(query)
        return list(result.scalars().all())