153 lines
4.3 KiB
Python
153 lines
4.3 KiB
Python
import logging
|
|
import uuid
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
|
from sqlalchemy import and_, select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.api.deps import get_current_user
|
|
from app.database import get_db
|
|
from app.models.brand import Brand
|
|
from app.models.user import User
|
|
from app.schemas.detection_task import (
|
|
DetectionTaskCreate,
|
|
DetectionTaskListResponse,
|
|
DetectionTaskResponse,
|
|
DetectionTaskUpdate,
|
|
DetectionTriggerResponse,
|
|
)
|
|
from app.services.detection.detection_scheduler import DetectionSchedulerService, TaskNotFoundError
|
|
|
|
# Module-scoped logger; note that none of the endpoints in this module use it.
logger: logging.Logger = logging.getLogger(__name__)

# Router that the detection-task endpoints below register themselves on.
router: APIRouter = APIRouter()
|
|
|
|
|
|
async def verify_brand_ownership(
    brand_id: uuid.UUID,
    current_user: User,
    db: AsyncSession,
) -> Brand:
    """Fetch a brand, requiring that it belongs to ``current_user``.

    Args:
        brand_id: ID of the brand to look up.
        current_user: Authenticated user whose ownership is being checked.
        db: Async SQLAlchemy session used to run the query.

    Returns:
        The ``Brand`` row matching both ``brand_id`` and ``current_user.id``.

    Raises:
        HTTPException: 404 when no such brand is owned by ``current_user``.
            A brand owned by another user is reported exactly like a
            missing one, so its existence is not revealed.
    """
    # Filter on the owner in SQL so another user's brand is never loaded.
    ownership = and_(
        Brand.id == brand_id,
        Brand.user_id == current_user.id,
    )
    query = select(Brand).where(ownership)
    owned_brand = (await db.execute(query)).scalar_one_or_none()

    if owned_brand:
        return owned_brand

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"品牌 {brand_id} 不存在或不属于当前用户",
    )
|
|
|
|
|
|
@router.post(
    "/tasks",
    response_model=DetectionTaskResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_detection_task(
    data: DetectionTaskCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a detection task under a brand owned by the current user.

    The brand ownership check runs first. A request that names another
    user's brand therefore fails with 404 before the scheduler service is
    called.

    Args:
        data: Request body. ``brand_id`` is passed to the service as its own
            argument and is excluded from the ``task_data`` payload.
        current_user: Authenticated user (injected).
        db: Async database session (injected).

    Returns:
        Whatever ``DetectionSchedulerService.create_task`` returns, serialized
        through ``DetectionTaskResponse``.

    Raises:
        HTTPException: 404 from ``verify_brand_ownership``; 422 when the
            service rejects the task data with ``ValueError``.
    """
    await verify_brand_ownership(data.brand_id, current_user, db)

    service = DetectionSchedulerService()
    try:
        task = await service.create_task(
            task_data=data.model_dump(exclude={"brand_id"}),
            brand_id=data.brand_id,
            user_id=current_user.id,
            db=db,
        )
    except ValueError as exc:
        # Chain explicitly so tracebacks show the service's validation error
        # as the direct cause of the 422 (ruff B904).
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=str(exc),
        ) from exc

    return task
|
|
|
|
|
|
@router.get("/tasks", response_model=DetectionTaskListResponse)
async def get_detection_tasks(
    brand_id: uuid.UUID = Query(..., description="按品牌筛选"),
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List one page of detection tasks for a brand owned by the current user.

    Args:
        brand_id: Brand whose tasks are listed (required query parameter).
        skip: Number of tasks to skip before the page starts.
        limit: Maximum number of tasks in the page (1-100).
        current_user: Authenticated user (injected).
        db: Async database session (injected).

    Returns:
        A dict with ``items`` (the requested slice) and ``total`` (the count
        of all tasks returned by the service, before slicing).

    Raises:
        HTTPException: 404 from ``verify_brand_ownership``.
    """
    await verify_brand_ownership(brand_id, current_user, db)

    all_tasks = await DetectionSchedulerService().get_tasks(
        brand_id, current_user.id, db
    )
    # NOTE(review): every task is loaded and the page is sliced in memory;
    # consider pushing skip/limit into the service query if task counts grow.
    page_end = skip + limit
    return {"items": all_tasks[skip:page_end], "total": len(all_tasks)}
|
|
|
|
|
|
@router.put("/tasks/{task_id}", response_model=DetectionTaskResponse)
async def update_detection_task(
    task_id: uuid.UUID,
    data: DetectionTaskUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Update a detection task using only the fields sent in the request.

    The payload is dumped with ``exclude_unset=True``, so only fields the
    client explicitly set are forwarded to the service.

    NOTE(review): unlike create/list, this endpoint does not call
    ``verify_brand_ownership``. Ownership is presumably enforced by the
    service through ``user_id``; confirm this. Also confirm that
    ``DetectionTaskUpdate`` cannot move a task to another brand.

    Args:
        task_id: ID of the task to update (path parameter).
        data: Partial update payload.
        current_user: Authenticated user (injected).
        db: Async database session (injected).

    Returns:
        The task returned by ``DetectionSchedulerService.update_task``.

    Raises:
        HTTPException: 404 when the service raises ``TaskNotFoundError``;
            422 when it raises ``ValueError``.
    """
    service = DetectionSchedulerService()
    try:
        task = await service.update_task(
            task_id=task_id,
            task_data=data.model_dump(exclude_unset=True),
            user_id=current_user.id,
            db=db,
        )
    # Keep this clause first: if TaskNotFoundError subclasses ValueError,
    # the ValueError handler would otherwise turn "not found" into a 422.
    except TaskNotFoundError as exc:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="检测任务不存在",
        ) from exc
    except ValueError as exc:
        # Explicit chaining keeps the validation error as the direct cause (B904).
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=str(exc),
        ) from exc

    return task
|
|
|
|
|
|
@router.delete("/tasks/{task_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_detection_task(
    task_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete a detection task belonging to the current user.

    Args:
        task_id: ID of the task to delete (path parameter).
        current_user: Authenticated user (injected).
        db: Async database session (injected).

    Returns:
        ``None``; the route is declared with status 204 (no content).

    Raises:
        HTTPException: 404 when ``delete_task`` returns a falsy value.
    """
    deleted = await DetectionSchedulerService().delete_task(
        task_id, current_user.id, db
    )
    if deleted:
        return None

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="检测任务不存在",
    )
|
|
|
|
|
|
@router.post("/tasks/{task_id}/trigger", response_model=DetectionTriggerResponse)
async def trigger_detection_task(
    task_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Manually trigger a run of a detection task.

    Args:
        task_id: ID of the task to trigger (path parameter).
        current_user: Authenticated user (injected).
        db: Async database session (injected).

    Returns:
        The mapping returned by ``DetectionSchedulerService.trigger_task``,
        serialized through ``DetectionTriggerResponse``.

    Raises:
        HTTPException: 404 when the service result has ``status == "error"``.
            The service's ``message`` is used as the detail when present.
    """
    outcome = await DetectionSchedulerService().trigger_task(
        task_id, current_user.id, db
    )
    # NOTE(review): every service-reported error becomes a 404, even errors
    # that may not mean "task not found". Confirm against trigger_task.
    is_error = outcome.get("status") == "error"
    if not is_error:
        return outcome

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=outcome.get("message", "检测任务不存在"),
    )
|