359 lines
13 KiB
Python
359 lines
13 KiB
Python
import difflib
|
||
import logging
|
||
import re
|
||
import uuid
|
||
from datetime import datetime, timedelta
|
||
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy import select
|
||
|
||
|
||
def _sanitize_raw_response(text: str | None) -> str:
|
||
"""清理原始响应中的无效控制字符,避免 PostgreSQL UTF-8 插入失败"""
|
||
if not text:
|
||
return ""
|
||
# 移除 NULL 字节及其他非法控制字符,保留 \n \t \r
|
||
return re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f]", "", text)
|
||
|
||
from app.models.citation_record import CitationRecord
|
||
from app.models.query import Query
|
||
from app.models.query_task import QueryTask
|
||
from app.workers.platforms.kimi import KimiAdapter
|
||
from app.workers.platforms.wenxin import WenxinAdapter
|
||
from app.workers.platforms.tongyi import TongyiAdapter
|
||
from app.workers.platforms.doubao import DoubaoAdapter
|
||
from app.workers.platforms.qingyan import QingyanAdapter
|
||
from app.workers.platforms.tiangong import TiangongAdapter
|
||
from app.workers.platforms.xinghuo import XinghuoAdapter
|
||
from app.workers.citation_extractor import analyze_citations
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class BrandMatcher:
|
||
"""品牌匹配器:检测文本中是否引用了目标品牌"""
|
||
|
||
def __init__(self, target_brand: str, brand_aliases: list[str] | None = None):
|
||
self.target_brand = target_brand
|
||
self.brand_aliases = brand_aliases or []
|
||
|
||
def match(self, text: str) -> dict:
|
||
"""
|
||
返回: {
|
||
"cited": bool,
|
||
"confidence": float, # 0.0-1.0
|
||
"match_type": str, # "exact"/"alias"/"fuzzy"/None
|
||
"position": int|None, # 在文本段落中的位置(第几段提到,1-based)
|
||
"citation_text": str|None, # 被引用的上下文片段
|
||
}
|
||
"""
|
||
if not text:
|
||
return {
|
||
"cited": False,
|
||
"confidence": 0.0,
|
||
"match_type": None,
|
||
"position": None,
|
||
"citation_text": None,
|
||
}
|
||
|
||
# 1. 精确匹配
|
||
if self.target_brand in text:
|
||
position, citation_text = self._extract_position_and_context(text, self.target_brand)
|
||
return {
|
||
"cited": True,
|
||
"confidence": 1.0,
|
||
"match_type": "exact",
|
||
"position": position,
|
||
"citation_text": citation_text,
|
||
}
|
||
|
||
# 2. 别名匹配
|
||
for alias in self.brand_aliases:
|
||
if alias in text:
|
||
position, citation_text = self._extract_position_and_context(text, alias)
|
||
return {
|
||
"cited": True,
|
||
"confidence": 0.9,
|
||
"match_type": "alias",
|
||
"position": position,
|
||
"citation_text": citation_text,
|
||
}
|
||
|
||
# 3. 模糊匹配
|
||
best_ratio = 0.0
|
||
best_match = None
|
||
for word in self._extract_candidates(text):
|
||
ratio = difflib.SequenceMatcher(None, self.target_brand, word).ratio()
|
||
if ratio > best_ratio:
|
||
best_ratio = ratio
|
||
best_match = word
|
||
|
||
for alias in self.brand_aliases:
|
||
for word in self._extract_candidates(text):
|
||
ratio = difflib.SequenceMatcher(None, alias, word).ratio()
|
||
if ratio > best_ratio:
|
||
best_ratio = ratio
|
||
best_match = word
|
||
|
||
if best_ratio > 0.4 and best_match:
|
||
position, citation_text = self._extract_position_and_context(text, best_match)
|
||
return {
|
||
"cited": True,
|
||
"confidence": round(best_ratio, 2),
|
||
"match_type": "fuzzy",
|
||
"position": position,
|
||
"citation_text": citation_text,
|
||
}
|
||
|
||
return {
|
||
"cited": False,
|
||
"confidence": 0.0,
|
||
"match_type": None,
|
||
"position": None,
|
||
"citation_text": None,
|
||
}
|
||
|
||
def _extract_candidates(self, text: str) -> list[str]:
|
||
"""从文本中提取候选词(按非文字字符分割)"""
|
||
# 匹配中文词组、英文单词等
|
||
return [w for w in re.split(r'[^\w\u4e00-\u9fff]+', text) if len(w) >= 2]
|
||
|
||
def _extract_position_and_context(self, text: str, keyword: str) -> tuple[int | None, str | None]:
|
||
"""提取品牌首次出现的段落位置(1-based)和上下文片段"""
|
||
paragraphs = [p.strip() for p in text.split('\n') if p.strip()]
|
||
if not paragraphs:
|
||
paragraphs = [text]
|
||
|
||
for idx, paragraph in enumerate(paragraphs, start=1):
|
||
if keyword in paragraph:
|
||
# 截取前200字符
|
||
snippet = paragraph[:200]
|
||
return idx, snippet
|
||
|
||
return None, None
|
||
|
||
|
||
class CompetitorDetector:
    """Competitor brand detector."""

    # Predefined lists of common brands, keyed by industry.
    KNOWN_BRANDS = {
        "保险": ["中国平安", "中国人寿", "太平洋保险", "新华保险", "泰康保险", "中国人保", "友邦保险"],
        "金融": ["工商银行", "建设银行", "农业银行", "中国银行", "招商银行", "交通银行"],
        "科技": ["华为", "腾讯", "阿里巴巴", "百度", "字节跳动", "小米", "京东"],
    }

    def detect(self, text: str, target_brand: str) -> list[str]:
        """List the known brands, other than ``target_brand``, found in ``text``.

        Args:
            text: Text to scan; empty or ``None`` yields ``[]``.
            target_brand: Brand to exclude (compared by exact equality).

        Returns:
            The matching brand names, de-duplicated and sorted.
        """
        if not text:
            return []

        found = {
            name
            for names in self.KNOWN_BRANDS.values()
            for name in names
            if name != target_brand and name in text
        }
        return sorted(found)
|
||
|
||
|
||
class CitationEngine:
    """Core of the citation-detection engine.

    For each requested platform the engine sends the keyword to that
    platform's adapter, analyses the reply for citations, checks whether the
    target brand and any known competitor brands are mentioned, and stores
    one ``CitationRecord`` per platform.
    """

    # Interval until the next run, per frequency label.
    _FREQUENCY_DELTAS = {
        "daily": timedelta(days=1),
        "weekly": timedelta(days=7),
        "monthly": timedelta(days=30),
    }
    # Used when the frequency label is missing or not in the table above.
    _DEFAULT_DELTA = timedelta(days=7)

    def __init__(self):
        """Instantiate one adapter per supported platform."""
        self.platforms = {
            "wenxin": WenxinAdapter(),
            "kimi": KimiAdapter(),
            "tongyi": TongyiAdapter(),
            "doubao": DoubaoAdapter(),
            "qingyan": QingyanAdapter(),
            "tiangong": TiangongAdapter(),
            "xinghuo": XinghuoAdapter(),
        }
        self.matcher = None
        self.competitor_detector = CompetitorDetector()

    @staticmethod
    def _utc_now() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Yields the same value as ``datetime.utcnow()``, which is deprecated
        since Python 3.12.
        """
        # Local import: ``timezone`` is not imported at file level.
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None)

    async def execute_query(self, query: Query, db: AsyncSession) -> list[CitationRecord]:
        """Run one query against every platform it lists.

        Steps:
            1. Build a ``BrandMatcher`` for the query's brand.
            2. For each platform in ``query.platforms`` (``wenxin`` and
               ``kimi`` when none are listed): mark its ``QueryTask`` as
               running, query the platform, and persist a ``CitationRecord``.
               A platform failure is logged, recorded on the task, and
               stored as a ``cited=False`` placeholder record; the loop then
               continues with the next platform.
            3. Update ``query.last_queried_at`` and ``query.next_query_at``.

        Args:
            query: The query to execute.
            db: Session used for all reads and commits.

        Returns:
            The records that were committed, in platform order.
        """
        # Snapshot the query's fields once so the loop does not re-read ORM
        # attributes after each commit.
        # NOTE(review): if the session expires instances on commit, attribute
        # access after a commit triggers an implicit reload - confirm the
        # session configuration.
        query_id = query.id
        keyword = query.keyword
        target_brand = query.target_brand
        brand_aliases = query.brand_aliases or []
        frequency = query.frequency
        platforms = query.platforms or ["wenxin", "kimi"]

        # Kept as a public attribute; execute_single_platform builds its own
        # matcher from its arguments and does not read this one.
        self.matcher = BrandMatcher(
            target_brand=target_brand,
            brand_aliases=brand_aliases,
        )

        records: list[CitationRecord] = []

        for platform_name in platforms:
            # Find or create the QueryTask for this platform.
            task = await self._get_or_create_task(db, query_id, platform_name)

            # Mark it as running before the (potentially slow) platform call.
            task.status = "running"
            task.started_at = self._utc_now()
            task.error_message = None
            await db.commit()

            try:
                result = await self.execute_single_platform(
                    keyword=keyword,
                    platform=platform_name,
                    target_brand=target_brand,
                    brand_aliases=brand_aliases,
                )

                record = CitationRecord(
                    query_id=query_id,
                    platform=platform_name,
                    cited=result["cited"],
                    citation_position=result.get("position"),
                    citation_text=result.get("citation_text"),
                    competitor_brands=result.get("competitor_brands", []),
                    raw_response=_sanitize_raw_response(result.get("raw_response", "")),
                    confidence=result.get("confidence"),
                    match_type=result.get("match_type"),
                    # Citation-source analysis fields.
                    data_source=result.get("data_source"),
                    source_urls=result.get("source_urls"),
                    source_titles=result.get("source_titles"),
                    citation_contexts=result.get("citation_contexts"),
                    ai_response_text=_sanitize_raw_response(result.get("ai_response_text", "")),
                )
                db.add(record)

                task.status = "success"
                task.completed_at = self._utc_now()
                await db.commit()
                # Only report the record once its commit succeeded; if the
                # commit raises, the placeholder below is returned instead.
                records.append(record)

            except Exception as e:
                # Best-effort per platform: log with traceback and carry on.
                logger.exception(f"平台 {platform_name} 查询失败: {e}")
                error_msg = str(e)
                task.status = "failed"
                task.error_message = error_msg
                task.completed_at = self._utc_now()

                # Store a cited=False record as a placeholder for the failure.
                record = CitationRecord(
                    query_id=query_id,
                    platform=platform_name,
                    cited=False,
                    raw_response=_sanitize_raw_response(error_msg),
                )
                db.add(record)
                await db.commit()
                records.append(record)

        # Update the query's scheduling timestamps.
        query.last_queried_at = self._utc_now()
        query.next_query_at = self._calculate_next_query_at(frequency)
        await db.commit()

        return records

    async def execute_single_platform(
        self,
        keyword: str,
        platform: str,
        target_brand: str,
        brand_aliases: list,
    ) -> dict:
        """Query one platform and analyse its reply.

        Args:
            keyword: Search keyword.
            platform: Key into ``self.platforms``.
            target_brand: Brand to look for in the reply.
            brand_aliases: Alternative names for the brand.

        Returns:
            A dict with the brand-match fields (``cited``, ``confidence``,
            ``match_type``, ``position``, ``citation_text``), the detected
            ``competitor_brands``, the ``raw_response``, and the
            citation-source fields (``data_source``, ``source_urls``,
            ``source_titles``, ``citation_contexts``, ``ai_response_text``).

        Raises:
            ValueError: If ``platform`` has no registered adapter.
        """
        adapter = self.platforms.get(platform)
        if not adapter:
            raise ValueError(f"不支持的平台: {platform}")

        # The keyword is combined with the target brand so that the result
        # includes brand information.
        search_keyword = f"{keyword} {target_brand}"
        raw_response = await adapter.query(search_keyword)

        # Citation-source analysis.
        citation_analysis = analyze_citations(raw_response)
        clean_text = citation_analysis.clean_response

        # Brand matching runs on the cleaned text rather than the raw reply.
        matcher = BrandMatcher(target_brand=target_brand, brand_aliases=brand_aliases)
        match_result = matcher.match(clean_text)

        # Competitor brand detection.
        competitor_brands = self.competitor_detector.detect(clean_text, target_brand)

        # Collect the non-empty source fields from each citation.
        citations = citation_analysis.citations
        source_urls = [c.source_url for c in citations if c.source_url]
        source_titles = [c.source_title for c in citations if c.source_title]
        citation_contexts = [c.citation_context for c in citations if c.citation_context]

        return {
            "cited": match_result["cited"],
            "confidence": match_result["confidence"],
            "match_type": match_result["match_type"],
            "position": match_result["position"],
            "citation_text": match_result["citation_text"],
            "competitor_brands": competitor_brands,
            "raw_response": raw_response,
            # Fields added by the citation-source analysis.
            "data_source": citation_analysis.data_source,
            "source_urls": source_urls,
            "source_titles": source_titles,
            "citation_contexts": citation_contexts,
            "ai_response_text": clean_text,
        }

    async def _get_or_create_task(
        self, db: AsyncSession, query_id: uuid.UUID, platform: str
    ) -> QueryTask:
        """Fetch the ``QueryTask`` for (query, platform), creating it if absent.

        A newly created task starts as ``"pending"`` and is committed and
        refreshed before being returned.
        """
        stmt = select(QueryTask).where(
            QueryTask.query_id == query_id,
            QueryTask.platform == platform,
        )
        result = await db.execute(stmt)
        task = result.scalar_one_or_none()

        if not task:
            task = QueryTask(
                query_id=query_id,
                platform=platform,
                status="pending",
            )
            db.add(task)
            await db.commit()
            await db.refresh(task)

        return task

    def _calculate_next_query_at(self, frequency: str | None) -> datetime:
        """Compute the next run time from a frequency label.

        Args:
            frequency: ``"daily"``, ``"weekly"`` or ``"monthly"``. ``None``,
                an empty string, or any other label is treated as weekly.

        Returns:
            Current naive UTC time plus 1, 7 or 30 days.
        """
        delta = self._FREQUENCY_DELTAS.get(frequency or "weekly", self._DEFAULT_DELTA)
        return self._utc_now() + delta

    async def close(self):
        """Close every platform adapter, logging failures instead of raising."""
        for name, adapter in self.platforms.items():
            try:
                await adapter.close()
            except Exception as e:
                # Fall back to the registry key so a missing ``platform_name``
                # cannot raise inside this handler and skip the remaining
                # adapters.
                label = getattr(adapter, "platform_name", name)
                logger.warning(f"关闭适配器 {label} 时出错: {e}")
|