# geo/backend/app/api/detection.py
# (file-viewer metadata from the original paste: 153 lines, 4.3 KiB, Python)

import logging
import uuid
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import and_, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.database import get_db
from app.models.brand import Brand
from app.models.user import User
from app.schemas.detection_task import (
DetectionTaskCreate,
DetectionTaskListResponse,
DetectionTaskResponse,
DetectionTaskUpdate,
DetectionTriggerResponse,
)
from app.services.detection.detection_scheduler import DetectionSchedulerService, TaskNotFoundError
# Module-level logger; no endpoint in the visible part of this file logs through it.
logger: logging.Logger = logging.getLogger(__name__)
# Router holding the detection-task endpoints below.
router: APIRouter = APIRouter()
async def verify_brand_ownership(
    brand_id: uuid.UUID,
    current_user: User,
    db: AsyncSession,
) -> Brand:
    """Load a brand and ensure it belongs to the current user.

    Args:
        brand_id: ID of the brand to look up.
        current_user: The authenticated user; the brand's ``user_id`` must
            equal ``current_user.id``.
        db: Async database session used to run the query.

    Returns:
        The matching ``Brand`` row.

    Raises:
        HTTPException: 404 when no brand matches both ``brand_id`` and the
            current user. The same response is returned whether the brand is
            missing or belongs to another user.
    """
    # Filter on id and owner in one query, so a brand owned by someone else
    # yields the same "no row" result as a brand that does not exist.
    stmt = select(Brand).where(
        and_(
            Brand.id == brand_id,
            Brand.user_id == current_user.id,
        )
    )
    result = await db.execute(stmt)
    brand = result.scalar_one_or_none()
    # scalar_one_or_none() signals "no match" with None; test for that
    # explicitly instead of relying on the model object's truthiness.
    if brand is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"品牌 {brand_id} 不存在或不属于当前用户",
        )
    return brand
@router.post("/tasks", response_model=DetectionTaskResponse, status_code=status.HTTP_201_CREATED)
async def create_detection_task(
    data: DetectionTaskCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a detection task for one of the current user's brands.

    The brand referenced by ``data.brand_id`` is checked for ownership first.
    The remaining payload fields are handed to the scheduler service as
    ``task_data``.

    Returns:
        The task object produced by ``DetectionSchedulerService.create_task``,
        serialized through ``DetectionTaskResponse`` with status 201.

    Raises:
        HTTPException: 404 if the brand is missing or not owned by the user.
        HTTPException: 422 if the service rejects the data with ``ValueError``.
    """
    await verify_brand_ownership(data.brand_id, current_user, db)
    service = DetectionSchedulerService()
    try:
        task = await service.create_task(
            # brand_id is passed as its own argument, so drop it from the dict.
            task_data=data.model_dump(exclude={"brand_id"}),
            brand_id=data.brand_id,
            user_id=current_user.id,
            db=db,
        )
    except ValueError as e:
        # Chain explicitly so the service's validation error is preserved as
        # __cause__ in tracebacks rather than as implicit context.
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=str(e),
        ) from e
    return task
@router.get("/tasks", response_model=DetectionTaskListResponse)
async def get_detection_tasks(
    brand_id: uuid.UUID = Query(..., description="按品牌筛选"),
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return one page of the current user's detection tasks for a brand.

    Args:
        brand_id: Brand to filter by (required query parameter).
        skip: Number of tasks to skip; must be >= 0.
        limit: Page size; between 1 and 100 inclusive.

    Returns:
        A mapping with ``items`` (the tasks in the requested page) and
        ``total`` (the number of tasks returned by the service before paging).

    Raises:
        HTTPException: 404 if the brand is missing or not owned by the user.
    """
    await verify_brand_ownership(brand_id, current_user, db)
    all_tasks = await DetectionSchedulerService().get_tasks(brand_id, current_user.id, db)
    # NOTE(review): paging happens in memory after the service returns every
    # task for the brand; if task counts can grow large, consider pushing
    # skip/limit (and the count) down into the service query.
    page = all_tasks[skip : skip + limit]
    return dict(items=page, total=len(all_tasks))
@router.put("/tasks/{task_id}", response_model=DetectionTaskResponse)
async def update_detection_task(
    task_id: uuid.UUID,
    data: DetectionTaskUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Update one of the current user's detection tasks.

    Only the fields the client explicitly sent are forwarded to the service
    (``model_dump(exclude_unset=True)``). Presumably the service leaves
    omitted fields unchanged; verify against ``update_task``.

    Returns:
        The updated task from ``DetectionSchedulerService.update_task``.

    Raises:
        HTTPException: 404 when the service raises ``TaskNotFoundError``.
        HTTPException: 422 when the service raises ``ValueError``.
    """
    # NOTE(review): unlike create/list, this endpoint does no brand-ownership
    # check itself and relies on the service scoping by user_id. Confirm this,
    # especially if DetectionTaskUpdate can change the task's brand.
    service = DetectionSchedulerService()
    try:
        task = await service.update_task(
            task_id=task_id,
            task_data=data.model_dump(exclude_unset=True),
            user_id=current_user.id,
            db=db,
        )
    # Keep TaskNotFoundError first in case it subclasses ValueError.
    except TaskNotFoundError as exc:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="检测任务不存在",
        ) from exc
    except ValueError as exc:
        # Explicit chaining keeps the service error visible as __cause__.
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=str(exc),
        ) from exc
    return task
@router.delete("/tasks/{task_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_detection_task(
    task_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of the current user's detection tasks.

    Responds with 204 No Content when the service reports a deletion.

    Raises:
        HTTPException: 404 when ``delete_task`` returns a falsy value.
    """
    deleted = await DetectionSchedulerService().delete_task(task_id, current_user.id, db)
    if deleted:
        return None
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="检测任务不存在",
    )
@router.post("/tasks/{task_id}/trigger", response_model=DetectionTriggerResponse)
async def trigger_detection_task(
    task_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Ask the scheduler service to trigger one of the user's detection tasks.

    Returns:
        The service's result mapping unchanged, serialized through the
        ``DetectionTriggerResponse`` response model.

    Raises:
        HTTPException: 404 whenever the result's ``status`` is ``"error"``.
            The detail is the result's ``message``, or a default text when
            no message is present.
    """
    outcome = await DetectionSchedulerService().trigger_task(task_id, current_user.id, db)
    # NOTE(review): every service-reported error becomes a 404, even when the
    # failure may not be "task not found". Confirm which errors trigger_task
    # can report and whether some deserve a different status code.
    if outcome.get("status") != "error":
        return outcome
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=outcome.get("message", "检测任务不存在"),
    )