113 lines
3.4 KiB
Python
113 lines
3.4 KiB
Python
import uuid
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.api.base import PaginationParams, PaginatedResponse
|
|
from app.api.deps import get_current_user
|
|
from app.database import get_db
|
|
from app.models.user import User
|
|
from app.schemas.citation import RunNowResponse
|
|
from app.schemas.query import QueryCreate, QueryListResponse, QueryResponse, QueryUpdate
|
|
from app.services.citation.citation import trigger_query_now
|
|
from app.services.query import create_query, delete_query, get_queries, get_query, update_query
|
|
|
|
# Router holding the query CRUD endpoints and the run-now trigger in this
# module; presumably mounted under a path prefix by the application — confirm.
router = APIRouter()
|
|
|
|
|
|
@router.get("/", response_model=QueryListResponse)
async def list_queries(
    pagination: PaginationParams = Depends(PaginationParams),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """List the authenticated user's queries one page at a time, with a total count."""
    # Translate the shared pagination dependency into the service's
    # skip/limit keyword arguments.
    page_offset = pagination.offset
    page_size = pagination.limit
    page_items, item_count = await get_queries(
        db,
        current_user.id,
        skip=page_offset,
        limit=page_size,
    )
    return {"items": page_items, "total": item_count}
|
|
|
|
|
|
@router.post("/", response_model=QueryResponse, status_code=status.HTTP_201_CREATED)
async def create_new_query(
    query_data: QueryCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Create a new query for the authenticated user; responds 201 on success."""
    try:
        query = await create_query(db, current_user.id, query_data)
    except PermissionError as e:
        # The service refuses creation by raising PermissionError. Its message
        # is used as the 403 detail; the fallback ("query-term count has
        # reached the limit") applies only when the exception has no message.
        # `from e` records the service error as the direct cause in tracebacks.
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=str(e) or "查询词数量已达上限",
        ) from e
    return query
|
|
|
|
|
|
@router.get("/{query_id}", response_model=QueryResponse)
async def retrieve_query(
    query_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Fetch a single query by id; responds 404 if the lookup finds nothing."""
    # The owner's id is passed to the lookup, so presumably another user's
    # query also yields None (and thus 404) — confirm in the service layer.
    found = await get_query(db, query_id, current_user.id)
    if found is not None:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="查询词不存在",  # "query does not exist"
    )
|
|
|
|
|
|
@router.put("/{query_id}", response_model=QueryResponse)
async def modify_query(
    query_id: uuid.UUID,
    update_data: QueryUpdate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Apply an update to a query by id; responds 404 if the update finds nothing."""
    # update_query signals "not found" by returning None rather than raising.
    updated = await update_query(db, query_id, current_user.id, update_data)
    if updated is not None:
        return updated
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="查询词不存在",  # "query does not exist"
    )
|
|
|
|
|
|
@router.delete("/{query_id}", status_code=status.HTTP_204_NO_CONTENT)
async def remove_query(
    query_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Delete a query by id; responds 204 on success, 404 if nothing was deleted."""
    if await delete_query(db, query_id, current_user.id):
        # No body: the route is declared with 204 No Content.
        return None
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="查询词不存在",  # "query does not exist"
    )
|
|
|
|
|
|
@router.post(
    "/{query_id}/run-now",
    response_model=RunNowResponse,
    status_code=status.HTTP_202_ACCEPTED,
)
async def run_query_now(
    query_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Queue an immediate run of a query; responds 202 with the task's id and status."""
    try:
        task = await trigger_query_now(db, current_user.id, query_id)
    except ValueError as e:
        # Every ValueError from the service becomes a 404 carrying its message.
        # NOTE(review): presumably it only signals an unknown query — confirm
        # the service raises ValueError for nothing else.
        # `from e` records the service error as the direct cause in tracebacks.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(e),
        ) from e
    return {
        "task_id": task.id,
        "status": task.status,
        "message": "查询任务已加入队列",  # "query task has been added to the queue"
    }
|