227 lines
6.9 KiB
Python
227 lines
6.9 KiB
Python
import logging
|
||
import uuid
|
||
from datetime import datetime, timezone
|
||
|
||
from sqlalchemy import and_, select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.models.brand import Brand
|
||
from app.models.detection_task import VALID_FREQUENCIES, DetectionTask
|
||
from app.services.ai_engine.base import EngineType
|
||
from app.services.ai_engine.batch_query import BatchQueryService, _build_adapters
|
||
|
||
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class TaskNotFoundError(Exception):
    """Raised when a detection task lookup (by id and owning user) finds nothing."""
|
||
|
||
|
||
class DetectionSchedulerService:
    """Create, update, delete, list and execute brand detection tasks.

    Every method receives the ``AsyncSession`` to use explicitly; the service
    keeps no per-instance state.
    """

    # Identity/ownership columns that ``update_task`` must never overwrite from
    # a caller-supplied payload (otherwise a task could be re-keyed or moved to
    # another user or brand).
    _PROTECTED_FIELDS = frozenset({"id", "user_id", "brand_id", "created_at"})

    async def create_task(
        self,
        task_data: dict,
        brand_id: uuid.UUID,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> DetectionTask:
        """Create, commit and return a new detection task.

        Args:
            task_data: Task attributes. ``"name"`` is required; ``"frequency"``
                defaults to ``"daily"``, ``"engines"`` and ``"queries"`` default
                to empty lists, ``"competitor_names"`` defaults to ``None``.
            brand_id: Brand the task belongs to.
            user_id: Owner of the task.
            db: Session used to persist the task; it is committed here.

        Returns:
            The refreshed task with ``next_run_at`` populated.

        Raises:
            ValueError: If the frequency is not in ``VALID_FREQUENCIES``.
            KeyError: If ``task_data`` has no ``"name"`` key.
        """
        frequency = task_data.get("frequency", "daily")
        self._validate_frequency(frequency)

        task = DetectionTask(
            brand_id=brand_id,
            user_id=user_id,
            name=task_data["name"],
            frequency=frequency,
            engines=task_data.get("engines", []),
            queries=task_data.get("queries", []),
            competitor_names=task_data.get("competitor_names"),
        )
        task.next_run_at = task.compute_next_run_at()
        db.add(task)
        await db.commit()
        await db.refresh(task)
        logger.info("Created detection task: %s, name=%s", task.id, task.name)
        return task

    async def update_task(
        self,
        task_id: uuid.UUID,
        task_data: dict,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> DetectionTask:
        """Apply ``task_data`` to an existing task owned by ``user_id``.

        Keys in ``_PROTECTED_FIELDS`` are ignored (with a warning) so a payload
        cannot change the task's identity or ownership. ``next_run_at`` is
        recomputed after the update.

        Raises:
            TaskNotFoundError: If no task with this id belongs to ``user_id``.
            ValueError: If the resulting frequency is invalid; the task is
                left unmodified in that case.
        """
        task = await self._get_task_by_id(task_id, user_id, db)
        if not task:
            raise TaskNotFoundError(f"Detection task {task_id} not found")

        # Validate before mutating anything so a bad frequency leaves the task untouched.
        frequency = task_data.get("frequency", task.frequency)
        self._validate_frequency(frequency)

        for field, value in task_data.items():
            if field in self._PROTECTED_FIELDS:
                logger.warning(
                    "Ignoring update of protected field %r on detection task %s",
                    field,
                    task.id,
                )
                continue
            setattr(task, field, value)

        task.next_run_at = task.compute_next_run_at()
        await db.commit()
        await db.refresh(task)
        logger.info("Updated detection task: %s", task.id)
        return task

    async def delete_task(
        self,
        task_id: uuid.UUID,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> bool:
        """Delete the task if it belongs to ``user_id``.

        Returns:
            ``True`` if a task was deleted and committed, ``False`` if none was found.
        """
        task = await self._get_task_by_id(task_id, user_id, db)
        if not task:
            return False

        await db.delete(task)
        await db.commit()
        logger.info("Deleted detection task: %s", task_id)
        return True

    async def get_tasks(
        self,
        brand_id: uuid.UUID,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> list[DetectionTask]:
        """Return the user's tasks for ``brand_id``, newest ``created_at`` first."""
        stmt = (
            select(DetectionTask)
            .where(
                and_(
                    DetectionTask.brand_id == brand_id,
                    DetectionTask.user_id == user_id,
                )
            )
            .order_by(DetectionTask.created_at.desc())
        )
        result = await db.execute(stmt)
        return list(result.scalars().all())

    async def trigger_task(
        self,
        task_id: uuid.UUID,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> dict:
        """Run a user's task immediately.

        Returns:
            An error dict if the task does not exist for this user, otherwise
            the result dict of ``execute_task``.
        """
        task = await self._get_task_by_id(task_id, user_id, db)
        if not task:
            return {"status": "error", "message": "Task not found"}

        return await self.execute_task(task, db)

    async def execute_task(
        self,
        task: DetectionTask,
        db: AsyncSession,
    ) -> dict:
        """Run all queries of ``task``, raise alerts if needed and reschedule it.

        On success, ``last_run_at`` is set to the current UTC time,
        ``next_run_at`` is recomputed and the session is committed. Any
        exception is caught: the session is rolled back and an error dict
        is returned instead of raising.

        Returns:
            ``{"status": "success", "results": [...]}`` or
            ``{"status": "error", "message": str(exc)}``.
        """
        # Read the id before anything can fail: after a rollback (or a failed
        # commit) the ORM instance is expired, and touching its attributes would
        # trigger a lazy reload, which is not allowed under an async session.
        task_id = task.id
        try:
            results = await self._run_batch_query(task, db)
            await self._generate_alerts_if_needed(task, results, db)

            task.last_run_at = datetime.now(timezone.utc)
            task.next_run_at = task.compute_next_run_at()
            await db.commit()
            await db.refresh(task)
        except Exception as e:
            # A failed flush/commit leaves the session unusable until it is
            # rolled back; without this, every later operation on the same
            # session (e.g. the next task in a scheduler loop) fails too.
            await db.rollback()
            logger.exception(
                "Detection task execution failed: %s, error=%s", task_id, e
            )
            return {"status": "error", "message": str(e)}

        logger.info("Detection task executed: %s", task_id)
        return {"status": "success", "results": results}

    async def _run_batch_query(
        self,
        task: DetectionTask,
        db: AsyncSession,
    ) -> list:
        """Query every configured engine for each of the task's queries.

        Engine identifiers that are not valid ``EngineType`` values are
        skipped. The brand name falls back to ``"Unknown"`` when the brand row
        is missing.

        Returns:
            The concatenation of all ``query_batch`` results, in query order.
        """
        batch_service = BatchQueryService(_build_adapters())

        brand = await self._get_brand(task.brand_id, db)
        brand_name = brand.name if brand else "Unknown"

        # The engine list is identical for every query, so parse it once.
        engines = [
            EngineType(e)
            for e in (task.engines or [])
            if e in EngineType._value2member_map_
        ]

        all_results: list = []
        for query_text in task.queries or []:
            results = await batch_service.query_batch(
                engines=engines,
                query=query_text,
                brand_name=brand_name,
                competitor_names=task.competitor_names,
            )
            all_results.extend(results)

        return all_results

    async def _generate_alerts_if_needed(
        self,
        task: DetectionTask,
        results: list,
        db: AsyncSession,
    ) -> list:
        """Create a ``competitor_overtake`` warning when only competitors are cited.

        The alert is raised when at least one result has a competitor citation
        and none has a brand citation. Nothing happens if ``results`` is empty
        or the brand no longer exists.

        Returns:
            The alerts that were created (possibly empty).
        """
        if not results:
            return []

        brand = await self._get_brand(task.brand_id, db)
        if not brand:
            return []

        brand_cited = sum(1 for r in results if r.has_brand_citation)
        competitor_cited = sum(1 for r in results if r.has_competitor_citation)

        alerts = []
        if competitor_cited > 0 and brand_cited == 0:
            # Imported lazily, as in the original code (presumably to avoid an
            # import cycle with the alert engine — TODO confirm).
            from app.services.alert_engine import AlertEngine

            alert_engine = AlertEngine(db)
            alert = await alert_engine._create_alert(
                brand_id=task.brand_id,
                user_id=task.user_id,
                alert_type="competitor_overtake",
                severity="warning",
                title=f"{brand.name} 未被引用但竞品被引用",
                message=f"检测任务「{task.name}」执行后发现:品牌未被AI引擎引用,但竞品被引用了 {competitor_cited} 次。",
                data={"task_id": str(task.id), "competitor_cited": competitor_cited},
            )
            if alert:
                alerts.append(alert)

        return alerts

    @staticmethod
    def _validate_frequency(frequency: str) -> None:
        """Raise ``ValueError`` unless ``frequency`` is in ``VALID_FREQUENCIES``."""
        if frequency not in VALID_FREQUENCIES:
            raise ValueError(
                f"Invalid frequency: {frequency}. Must be one of {VALID_FREQUENCIES}"
            )

    async def _get_task_by_id(
        self,
        task_id: uuid.UUID,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> DetectionTask | None:
        """Return the task with ``task_id`` owned by ``user_id``, or ``None``."""
        stmt = select(DetectionTask).where(
            and_(
                DetectionTask.id == task_id,
                DetectionTask.user_id == user_id,
            )
        )
        result = await db.execute(stmt)
        return result.scalar_one_or_none()

    async def _get_brand(
        self,
        brand_id: uuid.UUID,
        db: AsyncSession,
    ) -> Brand | None:
        """Return the brand row for ``brand_id``, or ``None`` if it does not exist."""
        stmt = select(Brand).where(Brand.id == brand_id)
        result = await db.execute(stmt)
        return result.scalar_one_or_none()
|