geo/backend/app/services/query.py

123 lines
3.5 KiB
Python

import uuid
from datetime import datetime, timedelta, timezone
from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.query import Query
from app.models.user import User
from app.schemas.query import QueryCreate, QueryUpdate
async def get_queries(
    db: AsyncSession,
    user_id: uuid.UUID,
    skip: int = 0,
    limit: int = 20,
) -> tuple[list[Query], int]:
    """Return one page of a user's queries together with their total count.

    Args:
        db: Async database session used for both statements.
        user_id: Owner whose queries are listed.
        skip: Number of rows to skip (SQL OFFSET).
        limit: Maximum number of rows to return (SQL LIMIT).

    Returns:
        A ``(items, total)`` tuple. ``items`` is the requested page ordered
        newest-first by ``created_at``. ``total`` counts every query owned by
        ``user_id``, regardless of ``skip``/``limit``.
    """
    stmt = (
        select(Query)
        .where(Query.user_id == user_id)
        # Tie-break on id (assumed unique) so rows sharing a created_at value
        # keep a stable order between requests; without it OFFSET pagination
        # may repeat or skip rows at page boundaries.
        .order_by(Query.created_at.desc(), Query.id.desc())
        .offset(skip)
        .limit(limit)
    )
    result = await db.execute(stmt)
    items = list(result.scalars().all())

    count_stmt = (
        select(func.count()).select_from(Query).where(Query.user_id == user_id)
    )
    count_result = await db.execute(count_stmt)
    total = count_result.scalar_one()
    return items, total
async def get_query(
    db: AsyncSession,
    query_id: uuid.UUID,
    user_id: uuid.UUID,
) -> Query | None:
    """Fetch a single query by id, restricted to the given owner.

    Args:
        db: Async database session.
        query_id: Id of the query to fetch.
        user_id: Owner the query must belong to.

    Returns:
        The matching ``Query``, or ``None`` when no query with ``query_id``
        is owned by ``user_id``.
    """
    owned_by_user = select(Query).where(
        Query.id == query_id,
        Query.user_id == user_id,
    )
    rows = await db.execute(owned_by_user)
    return rows.scalar_one_or_none()
async def create_query(
    db: AsyncSession,
    user_id: uuid.UUID,
    query_data: QueryCreate,
) -> Query:
    """Create and persist a new query for a user, enforcing their quota.

    Args:
        db: Async database session; the new row is committed on success.
        user_id: Owner of the new query.
        query_data: Validated creation payload.

    Returns:
        The newly created ``Query``, refreshed from the database.

    Raises:
        PermissionError: If the user already owns ``user.max_queries``
            queries.
        Also propagates the error raised by ``scalar_one()`` when no user row
        matches ``user_id``.
    """
    # Check user's current query count against max_queries limit.
    count_stmt = select(func.count()).select_from(Query).where(Query.user_id == user_id)
    count_result = await db.execute(count_stmt)
    current_count = count_result.scalar_one()

    user_stmt = select(User).where(User.id == user_id)
    user_result = await db.execute(user_stmt)
    user = user_result.scalar_one()
    # NOTE(review): count-then-insert is not atomic; concurrent requests for
    # the same user could each pass this check and exceed the limit.
    if current_count >= user.max_queries:
        raise PermissionError("Query limit exceeded")

    # Schedule the first run. The timestamp is kept naive (UTC) for DB
    # compatibility; datetime.utcnow() is deprecated since Python 3.12, so the
    # same naive UTC value is derived from an aware "now" instead.
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    if query_data.frequency == "daily":
        next_query_at = now + timedelta(days=1)
    else:  # weekly (any non-"daily" value is treated as weekly)
        next_query_at = now + timedelta(days=7)

    query = Query(
        user_id=user_id,
        keyword=query_data.keyword,
        target_brand=query_data.target_brand,
        brand_aliases=query_data.brand_aliases or [],
        platforms=query_data.platforms,
        frequency=query_data.frequency,
        next_query_at=next_query_at,
    )
    db.add(query)
    await db.commit()
    await db.refresh(query)
    return query
async def update_query(
    db: AsyncSession,
    query_id: uuid.UUID,
    user_id: uuid.UUID,
    update_data: QueryUpdate,
) -> Query | None:
    """Apply a partial update to one of a user's queries.

    Only fields explicitly set on ``update_data`` are written
    (``model_dump(exclude_unset=True)``). When ``frequency`` is among them,
    ``next_query_at`` is rescheduled relative to the current time.

    Args:
        db: Async database session; changes are committed on success.
        query_id: Id of the query to update.
        user_id: Owner the query must belong to.
        update_data: Partial update payload.

    Returns:
        The updated ``Query`` refreshed from the database, or ``None`` when
        no query with ``query_id`` is owned by ``user_id``.
    """
    stmt = select(Query).where(Query.id == query_id, Query.user_id == user_id)
    result = await db.execute(stmt)
    query = result.scalar_one_or_none()
    if query is None:
        return None

    update_dict = update_data.model_dump(exclude_unset=True)

    # Recalculate next_query_at if frequency is present in the update. The
    # timestamp stays naive (UTC) for DB compatibility; datetime.utcnow() is
    # deprecated since Python 3.12, so derive the same value from an aware now.
    if "frequency" in update_dict:
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        if update_dict["frequency"] == "daily":
            query.next_query_at = now + timedelta(days=1)
        else:  # weekly (any non-"daily" value is treated as weekly)
            query.next_query_at = now + timedelta(days=7)

    # Applied after the reschedule, so an explicit next_query_at in the
    # payload (if the schema permits one) takes precedence.
    for field, value in update_dict.items():
        setattr(query, field, value)

    await db.commit()
    await db.refresh(query)
    return query
async def delete_query(
    db: AsyncSession,
    query_id: uuid.UUID,
    user_id: uuid.UUID,
) -> bool:
    """Delete a query if it belongs to the given user, then commit.

    Args:
        db: Async database session.
        query_id: Id of the query to delete.
        user_id: Owner the query must belong to.

    Returns:
        ``True`` if the DELETE reported at least one affected row, ``False``
        otherwise (e.g. the id does not exist or belongs to another user).
    """
    ownership = (Query.id == query_id, Query.user_id == user_id)
    outcome = await db.execute(delete(Query).where(*ownership))
    await db.commit()
    rows_removed = outcome.rowcount
    return rows_removed > 0