geo/backend/app/workers/scheduler.py

202 lines
8.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Scheduled job runner for citation queries.

- Uses APScheduler's AsyncIOScheduler.
- Every hour: selects rows from the queries table with status='active' and
  next_query_at <= now, and runs each one through CitationEngine.
- Every minute: sweeps query tasks that are still 'pending' more than one
  minute after their scheduled time and executes them as a fallback.
"""
import asyncio
import logging
from datetime import datetime, timedelta, timezone
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.interval import IntervalTrigger
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import AsyncSessionLocal
from app.models.citation_record import CitationRecord
from app.models.query import Query
from app.models.query_task import QueryTask
from app.workers.citation_engine import CitationEngine
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class QueryScheduler:
    """Run due queries hourly and sweep stale pending tasks every minute.

    The class owns an ``AsyncIOScheduler`` and a ``CitationEngine``. The jobs
    registered with the scheduler are plain synchronous wrappers that hand the
    real coroutine work over to the event loop captured in :meth:`start`.
    """

    def __init__(self):
        """Create the scheduler and the engine; nothing is started yet."""
        self.scheduler = AsyncIOScheduler()
        self.engine = CitationEngine()
        # Event loop captured by start(); the sync job wrappers submit to it.
        self._loop = None

    def start(self):
        """Register both periodic jobs and start the scheduler."""
        # Prefer the loop we are actually running on; fall back to the
        # original lookup when start() is called outside a running loop.
        try:
            self._loop = asyncio.get_running_loop()
        except RuntimeError:
            self._loop = asyncio.get_event_loop()
        self.scheduler.add_job(
            self._run_check,
            trigger=IntervalTrigger(hours=1),
            id="check_queries",
            name="检查并执行到期的查询任务",
            replace_existing=True,
        )
        self.scheduler.add_job(
            self._run_pending_tasks_check,
            trigger=IntervalTrigger(minutes=1),
            id="check_pending_tasks",
            name="检查并执行遗留的pending查询任务",
            replace_existing=True,
        )
        self.scheduler.start()
        logger.info("查询调度器已启动每小时检查一次待执行任务每分钟检查一次遗留pending任务")

    @staticmethod
    def _utcnow_naive():
        """Return the current UTC time as a naive ``datetime``.

        Same value as the deprecated ``datetime.utcnow()``. The result is kept
        naive because the task timestamps were written and compared as naive
        values before. NOTE(review): presumably the QueryTask columns are
        timezone-naive; confirm against the model.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _log_job_failure(future):
        """Log an exception that escaped a coroutine submitted by ``_submit``."""
        if future.cancelled():
            return
        exc = future.exception()
        if exc is not None:
            logger.error("调度任务出现未捕获的异常", exc_info=exc)

    def _submit(self, coro_factory):
        """Run ``coro_factory()`` from synchronous code.

        If the captured loop is running, the coroutine is scheduled on it
        without blocking, and a callback reports any exception that would
        otherwise be lost with the discarded future. Otherwise the coroutine
        is driven to completion with ``asyncio.run``.
        """
        if self._loop and self._loop.is_running():
            future = asyncio.run_coroutine_threadsafe(coro_factory(), self._loop)
            future.add_done_callback(self._log_job_failure)
        else:
            asyncio.run(coro_factory())

    def _run_check(self):
        """Sync wrapper: dispatch the due-query check onto the event loop."""
        self._submit(self.check_and_execute_queries)

    async def check_and_execute_queries(self) -> None:
        """Find active queries whose ``next_query_at`` has passed and run them.

        A failing query is logged and rolled back so the shared session stays
        usable for the remaining queries. Errors are logged, not raised.
        """
        logger.info("开始检查待执行查询任务...")
        async with AsyncSessionLocal() as db:
            try:
                # NOTE(review): this comparison uses an aware UTC datetime,
                # while the pending-task path uses naive UTC values.
                now = datetime.now(timezone.utc)
                stmt = select(Query).where(
                    Query.status == "active",
                    Query.next_query_at <= now,
                )
                result = await db.execute(stmt)
                queries = result.scalars().all()
                logger.info(f"找到 {len(queries)} 个待执行查询")
                # Snapshot the ids while the rows are freshly loaded, so the
                # failure log never has to touch an expired ORM instance.
                due = [(query.id, query) for query in queries]
                for query_id, query in due:
                    try:
                        # A rollback (or an expiring commit) from an earlier
                        # iteration leaves this instance expired; reload it
                        # explicitly instead of relying on implicit lazy I/O.
                        await db.refresh(query)
                        await self._execute_single_query(query, db)
                    except Exception as e:
                        logger.error(f"执行查询 {query_id} 失败: {e}", exc_info=True)
                        # Discard the failed query's partial work and reset
                        # the transaction, as the pending-task path does.
                        await db.rollback()
            except Exception as e:
                logger.error(f"检查查询任务时出错: {e}", exc_info=True)

    async def _execute_single_query(self, query: "Query", db: "AsyncSession") -> None:
        """Run one query through the engine, logging start, success and failure.

        Raises:
            Exception: whatever ``engine.execute_query`` raises is logged and
                re-raised for the caller to handle.
        """
        # Read the id once up front: the engine may leave the instance expired
        # on failure, and the error log must not trigger a reload.
        query_id = query.id
        logger.info(f"开始执行查询: {query.keyword} (ID: {query_id})")
        try:
            await self.engine.execute_query(query, db)
            logger.info(f"查询 {query_id} 执行完成")
        except Exception as e:
            logger.error(f"查询 {query_id} 执行失败: {e}")
            raise

    def _run_pending_tasks_check(self):
        """Sync wrapper: dispatch the pending-task sweep onto the event loop."""
        self._submit(self.check_and_execute_pending_tasks)

    async def check_and_execute_pending_tasks(self) -> None:
        """Fallback sweep for tasks still ``pending`` a minute after schedule.

        Tasks are grouped by query; groups whose query is missing or not
        ``active`` are skipped. Errors are logged, not raised.
        """
        logger.info("检查并执行遗留的 pending 查询任务...")
        async with AsyncSessionLocal() as db:
            try:
                one_minute_ago = self._utcnow_naive() - timedelta(minutes=1)
                stmt = select(QueryTask).where(
                    QueryTask.status == "pending",
                    QueryTask.scheduled_at <= one_minute_ago,
                )
                result = await db.execute(stmt)
                tasks = result.scalars().all()
                logger.info(f"找到 {len(tasks)} 个遗留的 pending 任务")

                # Group by query and snapshot the scalars needed later. Once
                # any task fails, the rollback expires every loaded instance,
                # so these values must not be read from the ORM afterwards.
                tasks_by_query = {}
                for task in tasks:
                    tasks_by_query.setdefault(task.query_id, []).append(
                        (task, task.id, task.platform)
                    )

                for query_id, entries in tasks_by_query.items():
                    query_stmt = select(Query).where(Query.id == query_id)
                    query_result = await db.execute(query_stmt)
                    query = query_result.scalar_one_or_none()
                    if not query or query.status != "active":
                        continue
                    # Same reasoning: copy the query fields right after the
                    # load, before the first commit or rollback of this group.
                    keyword = query.keyword
                    target_brand = query.target_brand
                    brand_aliases = query.brand_aliases or []
                    for task, task_id, platform in entries:
                        await self._run_pending_task(
                            db,
                            task,
                            task_id=task_id,
                            platform=platform,
                            query_id=query_id,
                            keyword=keyword,
                            target_brand=target_brand,
                            brand_aliases=brand_aliases,
                        )
            except Exception as e:
                logger.error(f"检查遗留任务时出错: {e}", exc_info=True)

    async def _run_pending_task(
        self,
        db,
        task,
        *,
        task_id,
        platform,
        query_id,
        keyword,
        target_brand,
        brand_aliases,
    ) -> None:
        """Execute one pending task and persist its outcome.

        The task is committed as ``running``, the engine is called for the
        task's platform, a ``CitationRecord`` is stored when the engine
        returns a truthy result, and the task is committed as ``success``.
        On any exception the transaction is rolled back and the task is
        committed as ``failed`` with the error text.
        """
        try:
            task.status = "running"
            task.started_at = self._utcnow_naive()
            task.error_message = None
            await db.commit()
            citation_result = await self.engine.execute_single_platform(
                keyword=keyword,
                platform=platform,
                target_brand=target_brand,
                brand_aliases=brand_aliases,
            )
            if citation_result:
                record = CitationRecord(
                    query_id=query_id,
                    platform=platform,
                    cited=citation_result.get("cited", False),
                    citation_position=citation_result.get("position"),
                    citation_text=citation_result.get("citation_text"),
                    competitor_brands=citation_result.get("competitor_brands", []),
                    raw_response=citation_result.get("raw_response", ""),
                    confidence=citation_result.get("confidence"),
                    match_type=citation_result.get("match_type"),
                    # Citation-source analysis fields
                    data_source=citation_result.get("data_source"),
                    source_urls=citation_result.get("source_urls"),
                    source_titles=citation_result.get("source_titles"),
                    citation_contexts=citation_result.get("citation_contexts"),
                    ai_response_text=citation_result.get("ai_response_text", ""),
                )
                db.add(record)
            task.status = "success"
            task.completed_at = self._utcnow_naive()
            await db.commit()
        except Exception as e:
            await db.rollback()
            # Only assignments from here on: the rollback expired the
            # instance, so reading its attributes would need a reload.
            task.status = "failed"
            task.error_message = str(e)
            task.completed_at = self._utcnow_naive()
            await db.commit()
            logger.error(f"执行遗留任务 {task_id} 失败: {e}", exc_info=True)

    async def shutdown(self):
        """Stop the scheduler without waiting for jobs, then close the engine."""
        self.scheduler.shutdown(wait=False)
        await self.engine.close()
        logger.info("查询调度器已关闭")
# Global scheduler instance
query_scheduler = QueryScheduler()
# Exported alias, kept for compatibility with tests
scheduler = query_scheduler.scheduler
def run_job_now(job_id: str):
    """Manually trigger the job registered under ``job_id``.

    The job is looked up on the global scheduler and its callable is invoked
    directly, without waiting for the trigger.

    Returns:
        True if a job with that id exists and was invoked, False otherwise.
    """
    target = query_scheduler.scheduler.get_job(job_id)
    if not target:
        return False
    # Call the job's registered function directly.
    target.func()
    return True