85 lines
2.7 KiB
Python
85 lines
2.7 KiB
Python
import uuid
|
|
from typing import Optional
|
|
|
|
from sqlalchemy import select, func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.citation_record import CitationRecord
|
|
|
|
|
|
class CitationRepository:
    """Data-access helpers for ``CitationRecord`` rows on an ``AsyncSession``.

    Write methods (``create``, ``update``, ``delete``) only ``flush`` so the
    SQL is emitted inside the session's current transaction; this class never
    commits or closes the session — that is left to the caller.
    """

    def __init__(self, session: AsyncSession):
        # The session is owned by the caller and shared with this repository.
        self.session = session

    async def get_by_id(self, id: uuid.UUID) -> Optional[CitationRecord]:
        """Return the record whose ``id`` equals *id*, or ``None`` if absent."""
        result = await self.session.execute(
            select(CitationRecord).where(CitationRecord.id == id)
        )
        return result.scalar_one_or_none()

    async def list_by_query(
        self, query_id: uuid.UUID, *, skip: int = 0, limit: int = 100
    ) -> list[CitationRecord]:
        """Return one page of records for *query_id*, newest first.

        Args:
            query_id: Value matched against ``CitationRecord.query_id``.
            skip: Number of rows to skip (SQL ``OFFSET``).
            limit: Maximum number of rows to return (SQL ``LIMIT``).

        Returns:
            Records ordered by ``queried_at`` descending.
        """
        stmt = (
            select(CitationRecord)
            .where(CitationRecord.query_id == query_id)
            # ``id`` is a tiebreaker: OFFSET/LIMIT paging needs a total order,
            # otherwise rows sharing a ``queried_at`` can repeat or be skipped
            # between pages.
            .order_by(CitationRecord.queried_at.desc(), CitationRecord.id.desc())
            .offset(skip)
            .limit(limit)
        )
        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    async def count_by_query(self, query_id: uuid.UUID) -> int:
        """Return how many records have ``query_id`` equal to *query_id*."""
        result = await self.session.execute(
            select(func.count())
            .select_from(CitationRecord)
            .where(CitationRecord.query_id == query_id)
        )
        return result.scalar_one()

    async def get_by_query_and_platform(
        self, query_id: uuid.UUID, platform: str
    ) -> Optional[CitationRecord]:
        """Return the newest record for *query_id* on *platform*, or ``None``.

        Nothing visible here guarantees a single row per (query, platform).
        When several exist, the most recent by ``queried_at`` is returned
        instead of raising ``MultipleResultsFound``. When the pair is unique,
        the result is the same as an exact lookup.
        """
        result = await self.session.execute(
            select(CitationRecord)
            .where(
                CitationRecord.query_id == query_id,
                CitationRecord.platform == platform,
            )
            .order_by(CitationRecord.queried_at.desc(), CitationRecord.id.desc())
            .limit(1)
        )
        return result.scalars().first()

    async def count_cited_by_brand(self, brand_name: str) -> int:
        """Count records whose ``cited`` flag is true (inner-joined to their query).

        NOTE(review): *brand_name* is currently NOT applied as a filter, so this
        counts cited records across every brand. The join on
        ``CitationRecord.query`` suggests the brand lives on the related query
        model, which is not visible here. Add that predicate once the column
        is confirmed.
        """
        result = await self.session.execute(
            select(func.count())
            .select_from(CitationRecord)
            .join(CitationRecord.query)
            .where(
                # ``== True`` builds a SQL expression; ``is True`` would not.
                CitationRecord.cited == True,  # noqa: E712
            )
        )
        return result.scalar_one()

    async def create(self, **kwargs) -> CitationRecord:
        """Build a ``CitationRecord`` from *kwargs*, add it, and flush.

        Flushing issues the INSERT immediately, so constraint violations
        surface here rather than at commit time.
        """
        instance = CitationRecord(**kwargs)
        self.session.add(instance)
        await self.session.flush()
        return instance

    async def update(self, id: uuid.UUID, **kwargs) -> Optional[CitationRecord]:
        """Assign *kwargs* onto the record *id* and flush.

        Keys that are not attributes of the record are silently ignored.

        Returns:
            The updated record, or ``None`` if no record matched *id*.
        """
        instance = await self.get_by_id(id)
        if instance is None:
            return None
        for key, value in kwargs.items():
            if hasattr(instance, key):
                setattr(instance, key, value)
        await self.session.flush()
        return instance

    async def delete(self, id: uuid.UUID) -> bool:
        """Delete the record *id* and flush.

        Returns:
            ``True`` if a record was found and deleted, ``False`` otherwise.
        """
        instance = await self.get_by_id(id)
        if instance is None:
            return False
        await self.session.delete(instance)
        await self.session.flush()
        return True
|