geo/backend/app/workers/citation_engine.py

359 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import difflib
import logging
import re
import uuid
from datetime import datetime, timedelta
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
def _sanitize_raw_response(text: str | None) -> str:
"""清理原始响应中的无效控制字符,避免 PostgreSQL UTF-8 插入失败"""
if not text:
return ""
# 移除 NULL 字节及其他非法控制字符,保留 \n \t \r
return re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f]", "", text)
from app.models.citation_record import CitationRecord
from app.models.query import Query
from app.models.query_task import QueryTask
from app.workers.platforms.kimi import KimiAdapter
from app.workers.platforms.wenxin import WenxinAdapter
from app.workers.platforms.tongyi import TongyiAdapter
from app.workers.platforms.doubao import DoubaoAdapter
from app.workers.platforms.qingyan import QingyanAdapter
from app.workers.platforms.tiangong import TiangongAdapter
from app.workers.platforms.xinghuo import XinghuoAdapter
from app.workers.citation_extractor import analyze_citations
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(name=__name__)
class BrandMatcher:
"""品牌匹配器:检测文本中是否引用了目标品牌"""
def __init__(self, target_brand: str, brand_aliases: list[str] | None = None):
self.target_brand = target_brand
self.brand_aliases = brand_aliases or []
def match(self, text: str) -> dict:
"""
返回: {
"cited": bool,
"confidence": float, # 0.0-1.0
"match_type": str, # "exact"/"alias"/"fuzzy"/None
"position": int|None, # 在文本段落中的位置第几段提到1-based
"citation_text": str|None, # 被引用的上下文片段
}
"""
if not text:
return {
"cited": False,
"confidence": 0.0,
"match_type": None,
"position": None,
"citation_text": None,
}
# 1. 精确匹配
if self.target_brand in text:
position, citation_text = self._extract_position_and_context(text, self.target_brand)
return {
"cited": True,
"confidence": 1.0,
"match_type": "exact",
"position": position,
"citation_text": citation_text,
}
# 2. 别名匹配
for alias in self.brand_aliases:
if alias in text:
position, citation_text = self._extract_position_and_context(text, alias)
return {
"cited": True,
"confidence": 0.9,
"match_type": "alias",
"position": position,
"citation_text": citation_text,
}
# 3. 模糊匹配
best_ratio = 0.0
best_match = None
for word in self._extract_candidates(text):
ratio = difflib.SequenceMatcher(None, self.target_brand, word).ratio()
if ratio > best_ratio:
best_ratio = ratio
best_match = word
for alias in self.brand_aliases:
for word in self._extract_candidates(text):
ratio = difflib.SequenceMatcher(None, alias, word).ratio()
if ratio > best_ratio:
best_ratio = ratio
best_match = word
if best_ratio > 0.4 and best_match:
position, citation_text = self._extract_position_and_context(text, best_match)
return {
"cited": True,
"confidence": round(best_ratio, 2),
"match_type": "fuzzy",
"position": position,
"citation_text": citation_text,
}
return {
"cited": False,
"confidence": 0.0,
"match_type": None,
"position": None,
"citation_text": None,
}
def _extract_candidates(self, text: str) -> list[str]:
"""从文本中提取候选词(按非文字字符分割)"""
# 匹配中文词组、英文单词等
return [w for w in re.split(r'[^\w\u4e00-\u9fff]+', text) if len(w) >= 2]
def _extract_position_and_context(self, text: str, keyword: str) -> tuple[int | None, str | None]:
"""提取品牌首次出现的段落位置1-based和上下文片段"""
paragraphs = [p.strip() for p in text.split('\n') if p.strip()]
if not paragraphs:
paragraphs = [text]
for idx, paragraph in enumerate(paragraphs, start=1):
if keyword in paragraph:
# 截取前200字符
snippet = paragraph[:200]
return idx, snippet
return None, None
class CompetitorDetector:
    """Detect competitor brands mentioned in a piece of text.

    Candidates come from the hard-coded ``KNOWN_BRANDS`` table. Every listed
    brand that appears verbatim in the text, other than the target brand
    itself, is reported.
    """

    # Predefined brand lists for a few industries (insurance, finance, tech).
    KNOWN_BRANDS = {
        "保险": ["中国平安", "中国人寿", "太平洋保险", "新华保险", "泰康保险", "中国人保", "友邦保险"],
        "金融": ["工商银行", "建设银行", "农业银行", "中国银行", "招商银行", "交通银行"],
        "科技": ["华为", "腾讯", "阿里巴巴", "百度", "字节跳动", "小米", "京东"],
    }

    def detect(self, text: str, target_brand: str) -> list[str]:
        """Return the known brands found in ``text``, sorted and de-duplicated.

        ``target_brand`` is excluded by exact string comparison. Empty or
        ``None`` text yields an empty list.
        """
        if not text:
            return []
        found = {
            name
            for industry_brands in self.KNOWN_BRANDS.values()
            for name in industry_brands
            if name != target_brand and name in text
        }
        return sorted(found)
class CitationEngine:
    """Core citation-detection engine.

    For a query, asks every configured AI platform about the query keyword,
    checks whether the answer mentions the target brand, and persists one
    ``CitationRecord`` per platform plus the matching ``QueryTask`` status.

    NOTE(review): timestamps use ``datetime.utcnow()`` (naive UTC, deprecated
    since Python 3.12); presumably this matches naive DB columns -- confirm
    before switching to timezone-aware datetimes.
    """

    def __init__(self):
        # Platform adapters keyed by the platform names used in ``Query.platforms``.
        self.platforms = {
            "wenxin": WenxinAdapter(),
            "kimi": KimiAdapter(),
            "tongyi": TongyiAdapter(),
            "doubao": DoubaoAdapter(),
            "qingyan": QingyanAdapter(),
            "tiangong": TiangongAdapter(),
            "xinghuo": XinghuoAdapter(),
        }
        # Set by ``execute_query`` to the current query's matcher. It is not read
        # inside this class (``execute_single_platform`` builds its own matcher);
        # presumably kept for external readers.
        self.matcher = None
        self.competitor_detector = CompetitorDetector()

    async def execute_query(self, query: Query, db: AsyncSession) -> list[CitationRecord]:
        """Run one query on every configured platform and persist the results.

        For each platform in ``query.platforms`` (default ``["wenxin", "kimi"]``):
        mark its ``QueryTask`` as running, query and analyze the platform, then
        commit a ``CitationRecord`` and mark the task ``success``. On any
        failure the task is marked ``failed`` and a ``cited=False`` placeholder
        record holding the error message is committed instead. Finally
        ``query.last_queried_at`` / ``query.next_query_at`` are updated.

        Returns:
            The committed records, one per platform, in platform order.
        """
        self.matcher = BrandMatcher(
            target_brand=query.target_brand,
            brand_aliases=query.brand_aliases or [],
        )
        # Snapshot the query fields up front: the rollback in the failure path
        # expires every instance in the session, and reloading an expired
        # attribute lazily is not possible on an AsyncSession.
        query_id = query.id
        keyword = query.keyword
        target_brand = query.target_brand
        brand_aliases = query.brand_aliases or []
        frequency = query.frequency
        platforms = query.platforms or ["wenxin", "kimi"]

        records: list[CitationRecord] = []
        for platform_name in platforms:
            task = await self._get_or_create_task(db, query_id, platform_name)
            task.status = "running"
            task.started_at = datetime.utcnow()
            task.error_message = None
            await db.commit()

            committing = False
            try:
                result = await self.execute_single_platform(
                    keyword=keyword,
                    platform=platform_name,
                    target_brand=target_brand,
                    brand_aliases=brand_aliases,
                )
                record = self._build_record(query_id, platform_name, result)
                db.add(record)
                task.status = "success"
                task.completed_at = datetime.utcnow()
                committing = True
                await db.commit()
                # Only report the record once it is actually persisted.
                records.append(record)
            except Exception as e:
                logger.exception("平台 %s 查询失败: %s", platform_name, e)
                if committing:
                    # A failed commit leaves the session unusable until it is
                    # rolled back; the rollback also discards the pending
                    # success record and task update.
                    await db.rollback()
                error_msg = str(e)
                task.status = "failed"
                task.error_message = _sanitize_raw_response(error_msg)
                task.completed_at = datetime.utcnow()
                # Placeholder record so every platform has a row for this run.
                record = CitationRecord(
                    query_id=query_id,
                    platform=platform_name,
                    cited=False,
                    raw_response=_sanitize_raw_response(error_msg),
                )
                db.add(record)
                await db.commit()
                records.append(record)

        query.last_queried_at = datetime.utcnow()
        query.next_query_at = self._calculate_next_query_at(frequency)
        await db.commit()
        return records

    @staticmethod
    def _build_record(query_id: uuid.UUID, platform_name: str, result: dict) -> CitationRecord:
        """Build the CitationRecord for a successful platform run.

        ``result`` is the dict returned by ``execute_single_platform``; text
        fields are passed through ``_sanitize_raw_response`` before storage.
        """
        return CitationRecord(
            query_id=query_id,
            platform=platform_name,
            cited=result["cited"],
            citation_position=result.get("position"),
            citation_text=result.get("citation_text"),
            competitor_brands=result.get("competitor_brands", []),
            raw_response=_sanitize_raw_response(result.get("raw_response", "")),
            confidence=result.get("confidence"),
            match_type=result.get("match_type"),
            # Citation-source analysis fields.
            data_source=result.get("data_source"),
            source_urls=result.get("source_urls"),
            source_titles=result.get("source_titles"),
            citation_contexts=result.get("citation_contexts"),
            ai_response_text=_sanitize_raw_response(result.get("ai_response_text", "")),
        )

    async def execute_single_platform(
        self,
        keyword: str,
        platform: str,
        target_brand: str,
        brand_aliases: list,
    ) -> dict:
        """Query one platform and analyze its answer.

        The platform is queried with ``f"{keyword} {target_brand}"``. Brand and
        competitor matching run on the cleaned response produced by
        ``analyze_citations``.

        Returns:
            A dict with the ``BrandMatcher.match`` fields (``cited``,
            ``confidence``, ``match_type``, ``position``, ``citation_text``),
            plus ``competitor_brands``, ``raw_response``, ``data_source``,
            ``source_urls``, ``source_titles``, ``citation_contexts`` and
            ``ai_response_text``.

        Raises:
            ValueError: if ``platform`` has no registered adapter.
        """
        adapter = self.platforms.get(platform)
        if not adapter:
            raise ValueError(f"不支持的平台: {platform}")

        # Combine the keyword with the target brand so the answer covers the brand.
        search_keyword = f"{keyword} {target_brand}"
        raw_response = await adapter.query(search_keyword)

        citation_analysis = analyze_citations(raw_response)
        clean_text = citation_analysis.clean_response

        matcher = BrandMatcher(target_brand=target_brand, brand_aliases=brand_aliases)
        match_result = matcher.match(clean_text)
        competitor_brands = self.competitor_detector.detect(clean_text, target_brand)

        citations = citation_analysis.citations
        return {
            "cited": match_result["cited"],
            "confidence": match_result["confidence"],
            "match_type": match_result["match_type"],
            "position": match_result["position"],
            "citation_text": match_result["citation_text"],
            "competitor_brands": competitor_brands,
            "raw_response": raw_response,
            # Citation-source analysis fields.
            "data_source": citation_analysis.data_source,
            "source_urls": [c.source_url for c in citations if c.source_url],
            "source_titles": [c.source_title for c in citations if c.source_title],
            "citation_contexts": [c.citation_context for c in citations if c.citation_context],
            "ai_response_text": clean_text,
        }

    async def _get_or_create_task(
        self, db: AsyncSession, query_id: uuid.UUID, platform: str
    ) -> QueryTask:
        """Return the QueryTask for ``(query_id, platform)``.

        If none exists, a ``"pending"`` task is created, committed and refreshed.
        ``scalar_one_or_none`` raises if more than one row matches.
        """
        stmt = select(QueryTask).where(
            QueryTask.query_id == query_id,
            QueryTask.platform == platform,
        )
        result = await db.execute(stmt)
        task = result.scalar_one_or_none()
        if not task:
            task = QueryTask(
                query_id=query_id,
                platform=platform,
                status="pending",
            )
            db.add(task)
            await db.commit()
            await db.refresh(task)
        return task

    def _calculate_next_query_at(self, frequency: str | None) -> datetime:
        """Return the next run time (naive UTC) for ``frequency``.

        ``"daily"`` -> +1 day, ``"weekly"`` -> +7 days, ``"monthly"`` -> +30 days;
        ``None`` or any other value falls back to +7 days.
        """
        now = datetime.utcnow()
        freq_map = {
            "daily": timedelta(days=1),
            "weekly": timedelta(days=7),
            "monthly": timedelta(days=30),
        }
        delta = freq_map.get(frequency or "weekly", timedelta(days=7))
        return now + delta

    async def close(self):
        """Close every platform adapter, logging (not raising) individual failures."""
        for name, adapter in self.platforms.items():
            try:
                await adapter.close()
            except Exception as e:
                # Fall back to the registry key so an adapter without a
                # ``platform_name`` attribute cannot abort the loop and leave
                # the remaining adapters open.
                label = getattr(adapter, "platform_name", name)
                logger.warning("关闭适配器 %s 时出错: %s", label, e)