123 lines
3.5 KiB
Python
123 lines
3.5 KiB
Python
import uuid
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
from sqlalchemy import delete, func, select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.query import Query
|
|
from app.models.user import User
|
|
from app.schemas.query import QueryCreate, QueryUpdate
|
|
|
|
|
|
async def get_queries(
    db: AsyncSession,
    user_id: uuid.UUID,
    skip: int = 0,
    limit: int = 20,
) -> tuple[list[Query], int]:
    """Return one page of a user's queries together with the user's total count.

    Args:
        db: Async session used to run both statements.
        user_id: Only queries whose ``user_id`` equals this value are considered.
        skip: Number of rows to skip (SQL OFFSET).
        limit: Maximum number of rows to return (SQL LIMIT).

    Returns:
        A ``(items, total)`` tuple. ``items`` is the requested page ordered by
        ``created_at`` descending; ``total`` counts every query of the user and
        is not affected by ``skip`` or ``limit``.
    """
    owned_by_user = Query.user_id == user_id

    # Page of rows, newest first.
    page_stmt = select(Query).where(owned_by_user)
    page_stmt = page_stmt.order_by(Query.created_at.desc())
    page_stmt = page_stmt.offset(skip).limit(limit)
    page_rows = (await db.execute(page_stmt)).scalars().all()

    # Separate COUNT so callers can compute pagination metadata.
    total_stmt = select(func.count()).select_from(Query).where(owned_by_user)
    total = (await db.execute(total_stmt)).scalar_one()

    return list(page_rows), total
|
|
|
|
|
|
async def get_query(
    db: AsyncSession,
    query_id: uuid.UUID,
    user_id: uuid.UUID,
) -> Query | None:
    """Fetch a single query by id, restricted to the given owner.

    Args:
        db: Async session used to run the lookup.
        query_id: Primary identifier of the query to fetch.
        user_id: The row must also carry this ``user_id`` to be returned.

    Returns:
        The matching ``Query``, or ``None`` when no row satisfies both
        conditions.
    """
    lookup = (
        select(Query)
        .where(Query.id == query_id)
        .where(Query.user_id == user_id)
    )
    rows = await db.execute(lookup)
    return rows.scalar_one_or_none()
|
|
|
|
|
|
async def create_query(
    db: AsyncSession,
    user_id: uuid.UUID,
    query_data: QueryCreate,
) -> Query:
    """Create a query for a user after checking the user's ``max_queries`` limit.

    Args:
        db: Async session used for the reads and the insert; the insert is
            committed inside this function.
        user_id: Owner of the new query.
        query_data: Payload supplying ``keyword``, ``target_brand``,
            ``brand_aliases``, ``platforms`` and ``frequency``.

    Returns:
        The newly persisted ``Query``, refreshed from the database after commit.

    Raises:
        PermissionError: If the user's current number of queries is already
            greater than or equal to ``user.max_queries``.
    """
    # Check user's current query count against max_queries limit
    count_stmt = select(func.count()).select_from(Query).where(Query.user_id == user_id)
    count_result = await db.execute(count_stmt)
    current_count = count_result.scalar_one()

    user_stmt = select(User).where(User.id == user_id)
    user_result = await db.execute(user_stmt)
    # scalar_one() expects exactly one row, so a missing user surfaces as an
    # exception from SQLAlchemy rather than as a None user.
    user = user_result.scalar_one()

    if current_count >= user.max_queries:
        raise PermissionError("Query limit exceeded")

    # Calculate next_query_at based on frequency. The value is stored as a
    # naive UTC datetime for DB compatibility; it is derived from an aware
    # "now" because datetime.utcnow() is deprecated since Python 3.12.
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    # Any frequency other than "daily" is scheduled weekly.
    interval = timedelta(days=1) if query_data.frequency == "daily" else timedelta(days=7)
    next_query_at = now + interval

    query = Query(
        user_id=user_id,
        keyword=query_data.keyword,
        target_brand=query_data.target_brand,
        brand_aliases=query_data.brand_aliases or [],
        platforms=query_data.platforms,
        frequency=query_data.frequency,
        next_query_at=next_query_at,
    )
    db.add(query)
    await db.commit()
    await db.refresh(query)
    return query
|
|
|
|
|
|
async def update_query(
    db: AsyncSession,
    query_id: uuid.UUID,
    user_id: uuid.UUID,
    update_data: QueryUpdate,
) -> Query | None:
    """Apply a partial update to one of a user's queries.

    Args:
        db: Async session used for the lookup and the update; changes are
            committed inside this function.
        query_id: Identifier of the query to update.
        user_id: The query must belong to this user to be updated.
        update_data: Partial payload; only fields that were explicitly set
            are applied (``model_dump(exclude_unset=True)``).

    Returns:
        The updated ``Query`` refreshed from the database, or ``None`` when no
        query matches both ``query_id`` and ``user_id``.
    """
    stmt = select(Query).where(Query.id == query_id, Query.user_id == user_id)
    result = await db.execute(stmt)
    query = result.scalar_one_or_none()
    if query is None:
        return None

    update_dict = update_data.model_dump(exclude_unset=True)

    # Recalculate next_query_at if frequency is updated
    if "frequency" in update_dict:
        # Naive UTC timestamp, derived from an aware "now" because
        # datetime.utcnow() is deprecated since Python 3.12.
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        # Any frequency other than "daily" is scheduled weekly.
        interval = (
            timedelta(days=1)
            if update_dict["frequency"] == "daily"
            else timedelta(days=7)
        )
        query.next_query_at = now + interval

    # Applied after the reschedule, so an explicitly supplied field in the
    # payload takes precedence over the value computed above.
    for field, value in update_dict.items():
        setattr(query, field, value)

    await db.commit()
    await db.refresh(query)
    return query
|
|
|
|
|
|
async def delete_query(
    db: AsyncSession,
    query_id: uuid.UUID,
    user_id: uuid.UUID,
) -> bool:
    """Delete a query by id, but only if it belongs to the given user.

    Args:
        db: Async session used to run the DELETE; committed inside this
            function whether or not a row matched.
        query_id: Identifier of the query to delete.
        user_id: The row must also carry this ``user_id`` to be deleted.

    Returns:
        ``True`` when the DELETE reported at least one affected row,
        ``False`` otherwise.
    """
    removal = (
        delete(Query)
        .where(Query.id == query_id)
        .where(Query.user_id == user_id)
    )
    outcome = await db.execute(removal)
    await db.commit()

    affected = outcome.rowcount
    return affected > 0
|