645 lines
21 KiB
Python
645 lines
21 KiB
Python
"""
|
||
告警检测引擎 - 检测品牌可见性变化并生成告警
|
||
|
||
告警类型:
|
||
- score_drop: 评分下降超过阈值(默认5分)
|
||
- score_rise: 评分上升超过阈值(默认5分)
|
||
- negative_sentiment: 出现负面情感
|
||
- competitor_overtake: 竞品超越(竞品评分超过我方)
|
||
- new_platform_mention: 新平台出现提及
|
||
|
||
严重程度:
|
||
- critical: 需要立即关注(如评分大幅下降、负面情感)
|
||
- warning: 需要留意(如评分小幅下降、竞品接近)
|
||
- info: 一般信息(如评分上升、新平台提及)
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
import uuid
|
||
from dataclasses import dataclass
|
||
from datetime import datetime, timedelta, timezone
|
||
from typing import Optional
|
||
|
||
from sqlalchemy import select, and_, func
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.models.alert import Alert
|
||
from app.models.alert_setting import AlertSetting
|
||
from app.models.brand import Brand
|
||
from app.models.citation_record import CitationRecord
|
||
from app.models.competitor import Competitor
|
||
from app.models.query import Query
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# ============================================================
|
||
# 默认告警配置
|
||
# ============================================================
|
||
|
||
DEFAULT_ALERT_CONFIGS = {
|
||
"score_drop": {
|
||
"enabled": True,
|
||
"threshold": 5.0,
|
||
"severity": "warning",
|
||
},
|
||
"score_rise": {
|
||
"enabled": True,
|
||
"threshold": 5.0,
|
||
"severity": "info",
|
||
},
|
||
"negative_sentiment": {
|
||
"enabled": True,
|
||
"threshold": 1.0, # 出现1次负面即触发
|
||
"severity": "critical",
|
||
},
|
||
"competitor_overtake": {
|
||
"enabled": True,
|
||
"threshold": 0.0, # 只要超越即触发
|
||
"severity": "warning",
|
||
},
|
||
"new_platform_mention": {
|
||
"enabled": True,
|
||
"threshold": 1.0, # 新平台出现1次提及即触发
|
||
"severity": "info",
|
||
},
|
||
}
|
||
|
||
|
||
# ============================================================
|
||
# 数据结构
|
||
# ============================================================
|
||
|
||
@dataclass
|
||
class AlertContext:
|
||
"""告警检测上下文,包含检测所需的所有数据"""
|
||
brand_id: uuid.UUID
|
||
brand_name: str
|
||
user_id: uuid.UUID
|
||
current_score: float
|
||
previous_score: Optional[float]
|
||
sentiment_counts: dict[str, int] # {"positive": int, "neutral": int, "negative": int}
|
||
brand_mentions: int
|
||
competitor_mentions: dict[str, int] # {competitor_name: mention_count}
|
||
competitor_scores: dict[str, float] # {competitor_name: score}
|
||
current_platforms: set[str] # 当前已有提及的平台集合
|
||
new_platforms: set[str] # 新出现提及的平台集合
|
||
|
||
|
||
# ============================================================
|
||
# 告警检测引擎
|
||
# ============================================================
|
||
|
||
class AlertEngine:
|
||
"""告警检测引擎"""
|
||
|
||
def __init__(self, db: AsyncSession):
|
||
self.db = db
|
||
|
||
async def get_alert_setting(
|
||
self,
|
||
brand_id: uuid.UUID,
|
||
alert_type: str,
|
||
) -> AlertSetting | None:
|
||
"""获取指定品牌和告警类型的设置"""
|
||
stmt = select(AlertSetting).where(
|
||
and_(
|
||
AlertSetting.brand_id == brand_id,
|
||
AlertSetting.alert_type == alert_type,
|
||
)
|
||
)
|
||
result = await self.db.execute(stmt)
|
||
return result.scalar_one_or_none()
|
||
|
||
async def is_alert_enabled(
|
||
self,
|
||
brand_id: uuid.UUID,
|
||
alert_type: str,
|
||
) -> bool:
|
||
"""检查告警是否启用"""
|
||
setting = await self.get_alert_setting(brand_id, alert_type)
|
||
if setting is None:
|
||
# 没有设置记录时使用默认值
|
||
return DEFAULT_ALERT_CONFIGS.get(alert_type, {}).get("enabled", True)
|
||
return setting.enabled
|
||
|
||
async def get_threshold(
|
||
self,
|
||
brand_id: uuid.UUID,
|
||
alert_type: str,
|
||
) -> float:
|
||
"""获取告警阈值"""
|
||
setting = await self.get_alert_setting(brand_id, alert_type)
|
||
if setting is not None and setting.threshold is not None:
|
||
return setting.threshold
|
||
return DEFAULT_ALERT_CONFIGS.get(alert_type, {}).get("threshold", 5.0)
|
||
|
||
async def _create_alert(
|
||
self,
|
||
brand_id: uuid.UUID,
|
||
user_id: uuid.UUID,
|
||
alert_type: str,
|
||
severity: str,
|
||
title: str,
|
||
message: str,
|
||
data: dict | None = None,
|
||
) -> Alert | None:
|
||
"""创建告警记录(内部方法)"""
|
||
# 检查是否启用
|
||
if not await self.is_alert_enabled(brand_id, alert_type):
|
||
logger.debug(f"告警类型 {alert_type} 对品牌 {brand_id} 已禁用,跳过")
|
||
return None
|
||
|
||
# 防止重复告警:检查最近1小时内是否已有相同类型的告警
|
||
one_hour_ago = datetime.now(timezone.utc) - timedelta(hours=1)
|
||
existing_stmt = select(Alert).where(
|
||
and_(
|
||
Alert.brand_id == brand_id,
|
||
Alert.alert_type == alert_type,
|
||
Alert.created_at >= one_hour_ago,
|
||
)
|
||
)
|
||
existing_result = await self.db.execute(existing_stmt)
|
||
existing = existing_result.scalar_one_or_none()
|
||
|
||
if existing:
|
||
logger.debug(f"品牌 {brand_id} 最近1小时已有 {alert_type} 告警,跳过重复创建")
|
||
return None
|
||
|
||
alert = Alert(
|
||
brand_id=brand_id,
|
||
user_id=user_id,
|
||
alert_type=alert_type,
|
||
severity=severity,
|
||
title=title,
|
||
message=message,
|
||
data=data,
|
||
is_read=False,
|
||
)
|
||
self.db.add(alert)
|
||
await self.db.flush()
|
||
|
||
logger.info(
|
||
f"创建告警: brand={brand_id}, type={alert_type}, "
|
||
f"severity={severity}, title={title}"
|
||
)
|
||
return alert
|
||
|
||
# ============================================================
|
||
# 五种告警检测方法
|
||
# ============================================================
|
||
|
||
async def check_score_drop(self, ctx: AlertContext) -> Alert | None:
|
||
"""
|
||
检测评分下降
|
||
|
||
当评分下降超过阈值时触发,严重程度根据下降幅度判断:
|
||
- 下降超过20分: critical
|
||
- 下降超过阈值: warning
|
||
"""
|
||
if ctx.previous_score is None:
|
||
return None
|
||
|
||
threshold = await self.get_threshold(ctx.brand_id, "score_drop")
|
||
drop = ctx.previous_score - ctx.current_score
|
||
|
||
if drop < threshold:
|
||
return None
|
||
|
||
severity = "critical" if drop >= 20 else "warning"
|
||
|
||
return await self._create_alert(
|
||
brand_id=ctx.brand_id,
|
||
user_id=ctx.user_id,
|
||
alert_type="score_drop",
|
||
severity=severity,
|
||
title=f"{ctx.brand_name} 评分下降 {drop:.1f} 分",
|
||
message=(
|
||
f"品牌「{ctx.brand_name}」的可见性评分从 {ctx.previous_score:.1f} "
|
||
f"下降至 {ctx.current_score:.1f},下降了 {drop:.1f} 分。"
|
||
f"请关注近期品牌在各AI平台的表现变化。"
|
||
),
|
||
data={
|
||
"previous_score": ctx.previous_score,
|
||
"current_score": ctx.current_score,
|
||
"drop": round(drop, 2),
|
||
},
|
||
)
|
||
|
||
async def check_score_rise(self, ctx: AlertContext) -> Alert | None:
|
||
"""
|
||
检测评分上升
|
||
|
||
当评分上升超过阈值时触发,severity 固定为 info
|
||
"""
|
||
if ctx.previous_score is None:
|
||
return None
|
||
|
||
threshold = await self.get_threshold(ctx.brand_id, "score_rise")
|
||
rise = ctx.current_score - ctx.previous_score
|
||
|
||
if rise < threshold:
|
||
return None
|
||
|
||
return await self._create_alert(
|
||
brand_id=ctx.brand_id,
|
||
user_id=ctx.user_id,
|
||
alert_type="score_rise",
|
||
severity="info",
|
||
title=f"{ctx.brand_name} 评分上升 {rise:.1f} 分",
|
||
message=(
|
||
f"品牌「{ctx.brand_name}」的可见性评分从 {ctx.previous_score:.1f} "
|
||
f"上升至 {ctx.current_score:.1f},上升了 {rise:.1f} 分。"
|
||
f"品牌在AI平台的表现有所提升。"
|
||
),
|
||
data={
|
||
"previous_score": ctx.previous_score,
|
||
"current_score": ctx.current_score,
|
||
"rise": round(rise, 2),
|
||
},
|
||
)
|
||
|
||
async def check_negative_sentiment(self, ctx: AlertContext) -> Alert | None:
|
||
"""
|
||
检测负面情感
|
||
|
||
当出现负面情感时触发,严重程度根据负面数量判断:
|
||
- 负面数量 >= 3: critical
|
||
- 负面数量 >= 1: warning
|
||
"""
|
||
negative_count = ctx.sentiment_counts.get("negative", 0)
|
||
threshold = await self.get_threshold(ctx.brand_id, "negative_sentiment")
|
||
|
||
if negative_count < threshold:
|
||
return None
|
||
|
||
severity = "critical" if negative_count >= 3 else "warning"
|
||
|
||
total = sum(ctx.sentiment_counts.values())
|
||
negative_rate = (negative_count / total * 100) if total > 0 else 0
|
||
|
||
return await self._create_alert(
|
||
brand_id=ctx.brand_id,
|
||
user_id=ctx.user_id,
|
||
alert_type="negative_sentiment",
|
||
severity=severity,
|
||
title=f"{ctx.brand_name} 检测到负面情感",
|
||
message=(
|
||
f"品牌「{ctx.brand_name}」在AI回答中检测到 {negative_count} 条负面提及"
|
||
f"(占比 {negative_rate:.1f}%)。"
|
||
f"请关注AI平台对品牌的负面评价内容,及时采取应对措施。"
|
||
),
|
||
data={
|
||
"negative_count": negative_count,
|
||
"total_count": total,
|
||
"negative_rate": round(negative_rate, 2),
|
||
"sentiment_counts": ctx.sentiment_counts,
|
||
},
|
||
)
|
||
|
||
async def check_competitor_overtake(self, ctx: AlertContext) -> Alert | None:
|
||
"""
|
||
检测竞品超越
|
||
|
||
当竞品评分超过我方时触发,严重程度根据差距判断:
|
||
- 竞品领先超过10分: critical
|
||
- 竞品领先: warning
|
||
"""
|
||
if not ctx.competitor_scores:
|
||
return None
|
||
|
||
overtake_competitors = []
|
||
for comp_name, comp_score in ctx.competitor_scores.items():
|
||
if comp_score > ctx.current_score:
|
||
overtake_competitors.append({
|
||
"name": comp_name,
|
||
"score": comp_score,
|
||
"gap": round(comp_score - ctx.current_score, 2),
|
||
})
|
||
|
||
if not overtake_competitors:
|
||
return None
|
||
|
||
# 找出差距最大的竞品
|
||
max_gap_competitor = max(overtake_competitors, key=lambda x: x["gap"])
|
||
severity = "critical" if max_gap_competitor["gap"] >= 10 else "warning"
|
||
|
||
comp_names = ", ".join(c["name"] for c in overtake_competitors)
|
||
|
||
return await self._create_alert(
|
||
brand_id=ctx.brand_id,
|
||
user_id=ctx.user_id,
|
||
alert_type="competitor_overtake",
|
||
severity=severity,
|
||
title=f"竞品 {comp_names} 评分已超越 {ctx.brand_name}",
|
||
message=(
|
||
f"品牌「{ctx.brand_name}」当前评分 {ctx.current_score:.1f},"
|
||
f"已被竞品「{comp_names}」超越。"
|
||
f"最大差距为 {max_gap_competitor['gap']:.1f} 分"
|
||
f"({max_gap_competitor['name']}:{max_gap_competitor['score']:.1f})。"
|
||
f"建议关注竞品策略,优化品牌在AI平台的表现。"
|
||
),
|
||
data={
|
||
"brand_score": ctx.current_score,
|
||
"overtake_competitors": overtake_competitors,
|
||
},
|
||
)
|
||
|
||
async def check_new_platform_mention(self, ctx: AlertContext) -> Alert | None:
|
||
"""
|
||
检测新平台出现提及
|
||
|
||
当品牌在之前未出现的平台上被提及时触发,severity 固定为 info
|
||
"""
|
||
if not ctx.new_platforms:
|
||
return None
|
||
|
||
platform_names = ", ".join(ctx.new_platforms)
|
||
|
||
# 平台名称映射
|
||
platform_labels = {
|
||
"wenxin": "文心一言",
|
||
"kimi": "Kimi",
|
||
"tongyi": "通义千问",
|
||
"doubao": "豆包",
|
||
"xinghuo": "讯飞星火",
|
||
"tiangong": "天工AI",
|
||
"qingyan": "智谱清言",
|
||
"search_engine": "搜索引擎",
|
||
}
|
||
|
||
display_names = [
|
||
platform_labels.get(p, p) for p in ctx.new_platforms
|
||
]
|
||
display_text = ", ".join(display_names)
|
||
|
||
return await self._create_alert(
|
||
brand_id=ctx.brand_id,
|
||
user_id=ctx.user_id,
|
||
alert_type="new_platform_mention",
|
||
severity="info",
|
||
title=f"{ctx.brand_name} 在新平台 {display_text} 被提及",
|
||
message=(
|
||
f"品牌「{ctx.brand_name}」在新的AI平台「{display_text}」中首次被提及。"
|
||
f"这是一个积极的信号,表明品牌的影响力正在扩展到更多平台。"
|
||
),
|
||
data={
|
||
"new_platforms": list(ctx.new_platforms),
|
||
"existing_platforms": list(ctx.current_platforms - ctx.new_platforms),
|
||
},
|
||
)
|
||
|
||
# ============================================================
|
||
# 综合检测入口
|
||
# ============================================================
|
||
|
||
async def detect_all(self, ctx: AlertContext) -> list[Alert]:
|
||
"""
|
||
执行所有告警检测
|
||
|
||
Args:
|
||
ctx: 告警检测上下文
|
||
|
||
Returns:
|
||
生成的告警列表
|
||
"""
|
||
alerts: list[Alert] = []
|
||
|
||
check_methods = [
|
||
self.check_score_drop,
|
||
self.check_score_rise,
|
||
self.check_negative_sentiment,
|
||
self.check_competitor_overtake,
|
||
self.check_new_platform_mention,
|
||
]
|
||
|
||
for check_method in check_methods:
|
||
try:
|
||
alert = await check_method(ctx)
|
||
if alert is not None:
|
||
alerts.append(alert)
|
||
except Exception as e:
|
||
logger.error(
|
||
f"告警检测失败: method={check_method.__name__}, "
|
||
f"brand={ctx.brand_id}, error={e}"
|
||
)
|
||
|
||
return alerts
|
||
|
||
# ============================================================
|
||
# 便捷方法:从评分结果构建上下文并检测
|
||
# ============================================================
|
||
|
||
async def detect_after_scoring(
|
||
self,
|
||
brand_id: uuid.UUID,
|
||
brand_name: str,
|
||
user_id: uuid.UUID,
|
||
current_score: float,
|
||
sentiment_counts: dict[str, int],
|
||
brand_mentions: int,
|
||
competitor_mentions: dict[str, int],
|
||
competitor_scores: dict[str, float] | None = None,
|
||
current_platforms: set[str] | None = None,
|
||
new_platforms: set[str] | None = None,
|
||
) -> list[Alert]:
|
||
"""
|
||
评分计算后执行告警检测
|
||
|
||
自动获取上一次评分进行对比,构建告警上下文后执行所有检测。
|
||
|
||
Args:
|
||
brand_id: 品牌ID
|
||
brand_name: 品牌名称
|
||
user_id: 用户ID
|
||
current_score: 当前评分
|
||
sentiment_counts: 情感分布统计
|
||
brand_mentions: 品牌提及次数
|
||
competitor_mentions: 竞品提及次数
|
||
competitor_scores: 竞品评分
|
||
current_platforms: 当前已有提及的平台集合
|
||
new_platforms: 新出现提及的平台集合
|
||
|
||
Returns:
|
||
生成的告警列表
|
||
"""
|
||
# 获取上一次评分(从最近的历史告警数据中获取,或从评分历史获取)
|
||
previous_score = await self._get_previous_score(brand_id, user_id)
|
||
|
||
ctx = AlertContext(
|
||
brand_id=brand_id,
|
||
brand_name=brand_name,
|
||
user_id=user_id,
|
||
current_score=current_score,
|
||
previous_score=previous_score,
|
||
sentiment_counts=sentiment_counts,
|
||
brand_mentions=brand_mentions,
|
||
competitor_mentions=competitor_mentions,
|
||
competitor_scores=competitor_scores or {},
|
||
current_platforms=current_platforms or set(),
|
||
new_platforms=new_platforms or set(),
|
||
)
|
||
|
||
return await self.detect_all(ctx)
|
||
|
||
async def _get_previous_score(
|
||
self,
|
||
brand_id: uuid.UUID,
|
||
user_id: uuid.UUID,
|
||
) -> float | None:
|
||
"""
|
||
获取品牌上一次评分
|
||
|
||
从 score_drop 或 score_rise 类型的告警数据中获取,
|
||
如果没有则从引用记录中估算。
|
||
"""
|
||
# 尝试从最近的评分变化告警中获取
|
||
recent_alert_stmt = (
|
||
select(Alert)
|
||
.where(
|
||
and_(
|
||
Alert.brand_id == brand_id,
|
||
Alert.alert_type.in_(["score_drop", "score_rise"]),
|
||
)
|
||
)
|
||
.order_by(Alert.created_at.desc())
|
||
.limit(1)
|
||
)
|
||
result = await self.db.execute(recent_alert_stmt)
|
||
recent_alert = result.scalar_one_or_none()
|
||
|
||
if recent_alert and recent_alert.data:
|
||
return recent_alert.data.get("current_score")
|
||
|
||
# 没有历史告警,尝试从引用记录估算历史评分
|
||
return await self._estimate_previous_score(brand_id, user_id)
|
||
|
||
async def _estimate_previous_score(
|
||
self,
|
||
brand_id: uuid.UUID,
|
||
user_id: uuid.UUID,
|
||
) -> float | None:
|
||
"""
|
||
从历史引用记录估算上一次评分
|
||
|
||
使用昨天的引用数据计算一个简化的评分
|
||
"""
|
||
brand_stmt = select(Brand).where(Brand.id == brand_id)
|
||
brand_result = await self.db.execute(brand_stmt)
|
||
brand = brand_result.scalar_one_or_none()
|
||
|
||
if not brand:
|
||
return None
|
||
|
||
# 获取昨天的引用数据
|
||
yesterday = datetime.now(timezone.utc).date() - timedelta(days=1)
|
||
today_start = datetime.combine(yesterday, datetime.min.time()).replace(
|
||
tzinfo=timezone.utc
|
||
)
|
||
today_end = datetime.combine(
|
||
yesterday + timedelta(days=1), datetime.min.time()
|
||
).replace(tzinfo=timezone.utc)
|
||
|
||
queries_stmt = select(Query).where(
|
||
and_(
|
||
Query.user_id == user_id,
|
||
Query.target_brand == brand.name,
|
||
)
|
||
)
|
||
queries_result = await self.db.execute(queries_stmt)
|
||
queries = list(queries_result.scalars().all())
|
||
|
||
if not queries:
|
||
return None
|
||
|
||
query_ids = [q.id for q in queries]
|
||
|
||
citations_stmt = select(CitationRecord).where(
|
||
and_(
|
||
CitationRecord.query_id.in_(query_ids),
|
||
CitationRecord.queried_at >= today_start,
|
||
CitationRecord.queried_at < today_end,
|
||
)
|
||
)
|
||
citations_result = await self.db.execute(citations_stmt)
|
||
citations = list(citations_result.scalars().all())
|
||
|
||
if not citations:
|
||
return None
|
||
|
||
# 简化评分:提及率 * 100
|
||
cited_count = sum(1 for c in citations if c.cited)
|
||
total_count = len(citations)
|
||
if total_count > 0:
|
||
return round((cited_count / total_count) * 100, 2)
|
||
|
||
return None
|
||
|
||
# ============================================================
|
||
# 告警清理
|
||
# ============================================================
|
||
|
||
async def cleanup_old_alerts(self, days: int = 90) -> int:
|
||
"""
|
||
清理超过指定天数的旧告警
|
||
|
||
Args:
|
||
days: 保留天数,默认90天
|
||
|
||
Returns:
|
||
删除的告警数量
|
||
"""
|
||
cutoff = datetime.now(timezone.utc) - timedelta(days=days)
|
||
stmt = select(Alert).where(Alert.created_at < cutoff)
|
||
result = await self.db.execute(stmt)
|
||
old_alerts = list(result.scalars().all())
|
||
|
||
count = len(old_alerts)
|
||
for alert in old_alerts:
|
||
await self.db.delete(alert)
|
||
|
||
await self.db.flush()
|
||
|
||
if count > 0:
|
||
logger.info(f"已清理 {count} 条超过 {days} 天的旧告警")
|
||
|
||
return count
|
||
|
||
# ============================================================
|
||
# 初始化默认告警设置
|
||
# ============================================================
|
||
|
||
async def ensure_default_settings(
|
||
self,
|
||
brand_id: uuid.UUID,
|
||
user_id: uuid.UUID,
|
||
) -> list[AlertSetting]:
|
||
"""
|
||
确保品牌有默认的告警设置
|
||
|
||
如果某种告警类型没有设置记录,则创建默认设置。
|
||
|
||
Returns:
|
||
所有告警设置列表
|
||
"""
|
||
settings: list[AlertSetting] = []
|
||
|
||
for alert_type, config in DEFAULT_ALERT_CONFIGS.items():
|
||
existing = await self.get_alert_setting(brand_id, alert_type)
|
||
if existing is None:
|
||
setting = AlertSetting(
|
||
brand_id=brand_id,
|
||
user_id=user_id,
|
||
alert_type=alert_type,
|
||
enabled=config["enabled"],
|
||
threshold=config["threshold"],
|
||
)
|
||
self.db.add(setting)
|
||
settings.append(setting)
|
||
else:
|
||
settings.append(existing)
|
||
|
||
await self.db.flush()
|
||
return settings
|