282 lines
9.3 KiB
Python
282 lines
9.3 KiB
Python
import difflib
|
|
import logging
|
|
import re
|
|
import uuid
|
|
from datetime import datetime, timedelta
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select
|
|
|
|
from app.models.citation_record import CitationRecord
|
|
from app.models.query import Query
|
|
from app.models.query_task import QueryTask
|
|
from app.services.ai_engine.platform_bridge import query_platform_raw
|
|
from app.workers.citation_extractor import analyze_citations
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BrandMatcher:
|
|
|
|
def __init__(self, target_brand: str, brand_aliases: list[str] | None = None):
|
|
self.target_brand = target_brand
|
|
self.brand_aliases = brand_aliases or []
|
|
|
|
def match(self, text: str) -> dict:
|
|
if not text:
|
|
return {
|
|
"cited": False,
|
|
"confidence": 0.0,
|
|
"match_type": None,
|
|
"position": None,
|
|
"citation_text": None,
|
|
}
|
|
|
|
if self.target_brand in text:
|
|
position, citation_text = self._extract_position_and_context(text, self.target_brand)
|
|
return {
|
|
"cited": True,
|
|
"confidence": 1.0,
|
|
"match_type": "exact",
|
|
"position": position,
|
|
"citation_text": citation_text,
|
|
}
|
|
|
|
for alias in self.brand_aliases:
|
|
if alias in text:
|
|
position, citation_text = self._extract_position_and_context(text, alias)
|
|
return {
|
|
"cited": True,
|
|
"confidence": 0.9,
|
|
"match_type": "alias",
|
|
"position": position,
|
|
"citation_text": citation_text,
|
|
}
|
|
|
|
best_ratio = 0.0
|
|
best_match = None
|
|
for word in self._extract_candidates(text):
|
|
ratio = difflib.SequenceMatcher(None, self.target_brand, word).ratio()
|
|
if ratio > best_ratio:
|
|
best_ratio = ratio
|
|
best_match = word
|
|
|
|
for alias in self.brand_aliases:
|
|
for word in self._extract_candidates(text):
|
|
ratio = difflib.SequenceMatcher(None, alias, word).ratio()
|
|
if ratio > best_ratio:
|
|
best_ratio = ratio
|
|
best_match = word
|
|
|
|
if best_ratio > 0.4 and best_match:
|
|
position, citation_text = self._extract_position_and_context(text, best_match)
|
|
return {
|
|
"cited": True,
|
|
"confidence": round(best_ratio, 2),
|
|
"match_type": "fuzzy",
|
|
"position": position,
|
|
"citation_text": citation_text,
|
|
}
|
|
|
|
return {
|
|
"cited": False,
|
|
"confidence": 0.0,
|
|
"match_type": None,
|
|
"position": None,
|
|
"citation_text": None,
|
|
}
|
|
|
|
def _extract_candidates(self, text: str) -> list[str]:
|
|
return [w for w in re.split(r'[^\w\u4e00-\u9fff]+', text) if len(w) >= 2]
|
|
|
|
def _extract_position_and_context(self, text: str, keyword: str) -> tuple[int | None, str | None]:
|
|
paragraphs = [p.strip() for p in text.split('\n') if p.strip()]
|
|
if not paragraphs:
|
|
paragraphs = [text]
|
|
|
|
for idx, paragraph in enumerate(paragraphs, start=1):
|
|
if keyword in paragraph:
|
|
snippet = paragraph[:200]
|
|
return idx, snippet
|
|
|
|
return None, None
|
|
|
|
|
|
class CompetitorDetector:
|
|
|
|
KNOWN_BRANDS = {
|
|
"保险": ["中国平安", "中国人寿", "太平洋保险", "新华保险", "泰康保险", "中国人保", "友邦保险"],
|
|
"金融": ["工商银行", "建设银行", "农业银行", "中国银行", "招商银行", "交通银行"],
|
|
"科技": ["华为", "腾讯", "阿里巴巴", "百度", "字节跳动", "小米", "京东"],
|
|
}
|
|
|
|
def detect(self, text: str, target_brand: str) -> list[str]:
|
|
if not text:
|
|
return []
|
|
|
|
competitors = set()
|
|
for category, brands in self.KNOWN_BRANDS.items():
|
|
for brand in brands:
|
|
if brand == target_brand:
|
|
continue
|
|
if brand in text:
|
|
competitors.add(brand)
|
|
|
|
return sorted(list(competitors))
|
|
|
|
|
|
class CitationEngine:
    """Run brand-citation checks for a ``Query`` across AI platforms and persist results."""

    def __init__(self):
        self._supported_platforms = {
            "wenxin", "kimi", "doubao", "tongyi",
            "qingyan", "tiangong", "xinghuo",
        }
        # Still assigned by execute_query for backward compatibility, but
        # execute_single_platform builds its own matcher per call and does not
        # read this attribute.
        self.matcher = None
        self.competitor_detector = CompetitorDetector()

    async def execute_query(self, query: Query, db: AsyncSession) -> list[CitationRecord]:
        """Query every platform configured on ``query`` and store one record per platform.

        For each platform, the matching ``QueryTask`` row is marked "running"
        and committed. The platform is then queried, and the task ends as
        "success" or "failed". Exactly one ``CitationRecord`` is added per
        platform; on failure its ``raw_response`` holds the error text.
        Finally ``last_queried_at``/``next_query_at`` on the query are updated.

        Args:
            query: Query to execute. Platforms default to ["wenxin", "kimi"].
            db: Session used for all reads/writes; committed several times.

        Returns:
            The created records, in platform order.

        Raises:
            Database errors raised by ``commit`` propagate; platform errors
            are recorded and do not abort the remaining platforms.
        """
        self.matcher = BrandMatcher(
            target_brand=query.target_brand,
            brand_aliases=query.brand_aliases or [],
        )

        records: list[CitationRecord] = []
        platforms = query.platforms or ["wenxin", "kimi"]

        for platform_name in platforms:
            task = await self._get_or_create_task(db, query.id, platform_name)

            task.status = "running"
            task.started_at = datetime.utcnow()
            task.error_message = None
            await db.commit()

            # Keep the try narrow: only the platform call and record building
            # are "expected" failures. The add/commit below happens once, so a
            # failing commit can no longer leave two records for one platform.
            try:
                result = await self.execute_single_platform(
                    keyword=query.keyword,
                    platform=platform_name,
                    target_brand=query.target_brand,
                    brand_aliases=query.brand_aliases or [],
                )
                record = CitationRecord.from_citation_result(
                    query_id=query.id,
                    platform=platform_name,
                    result=result,
                )
            except Exception as exc:
                logger.exception("平台 %s 查询失败: %s", platform_name, exc)
                error_msg = str(exc)
                task.status = "failed"
                task.error_message = error_msg
                record = CitationRecord.from_citation_result(
                    query_id=query.id,
                    platform=platform_name,
                    result={"cited": False, "raw_response": error_msg},
                )
            else:
                task.status = "success"

            task.completed_at = datetime.utcnow()
            db.add(record)
            records.append(record)
            await db.commit()

        query.last_queried_at = datetime.utcnow()
        query.next_query_at = self._calculate_next_query_at(query.frequency)
        await db.commit()

        return records

    async def execute_single_platform(
        self,
        keyword: str,
        platform: str,
        target_brand: str,
        brand_aliases: list,
    ) -> dict:
        """Query one platform and analyse its answer for brand citations.

        Args:
            keyword: Base search keyword.
            platform: Platform name; must be in the supported set.
            target_brand: Brand to look for in the answer.
            brand_aliases: Alternative names of the brand.

        Returns:
            A dict with the ``BrandMatcher.match`` fields, plus
            ``competitor_brands``, ``raw_response``, ``data_source``,
            ``source_urls``, ``source_titles``, ``citation_contexts`` and
            ``ai_response_text``.

        Raises:
            ValueError: If ``platform`` is not supported.
        """
        if platform not in self._supported_platforms:
            raise ValueError(f"不支持的平台: {platform}")

        # NOTE(review): the brand is appended to the prompt keyword, which may
        # bias the platform toward mentioning it - confirm this is intended.
        search_keyword = f"{keyword} {target_brand}"
        raw_response = await query_platform_raw(
            platform_name=platform,
            keyword=search_keyword,
            brand_name=target_brand,
        )

        citation_analysis = analyze_citations(raw_response)
        answer_text = citation_analysis.clean_response

        matcher = BrandMatcher(target_brand=target_brand, brand_aliases=brand_aliases)
        match_result = matcher.match(answer_text)

        competitor_brands = self.competitor_detector.detect(answer_text, target_brand)

        citations = citation_analysis.citations
        source_urls = [c.source_url for c in citations if c.source_url]
        source_titles = [c.source_title for c in citations if c.source_title]
        citation_contexts = [c.citation_context for c in citations if c.citation_context]

        return {
            "cited": match_result["cited"],
            "confidence": match_result["confidence"],
            "match_type": match_result["match_type"],
            "position": match_result["position"],
            "citation_text": match_result["citation_text"],
            "competitor_brands": competitor_brands,
            "raw_response": raw_response,
            "data_source": citation_analysis.data_source,
            "source_urls": source_urls,
            "source_titles": source_titles,
            "citation_contexts": citation_contexts,
            "ai_response_text": answer_text,
        }

    async def _get_or_create_task(
        self, db: AsyncSession, query_id: uuid.UUID, platform: str
    ) -> QueryTask:
        """Fetch the task row for (query_id, platform), creating a "pending" one if absent.

        ``scalar_one_or_none`` raises if more than one row matches, so this
        assumes at most one task per (query_id, platform). TODO confirm a
        unique constraint backs that assumption.
        """
        stmt = select(QueryTask).where(
            QueryTask.query_id == query_id,
            QueryTask.platform == platform,
        )
        result = await db.execute(stmt)
        task = result.scalar_one_or_none()

        if not task:
            task = QueryTask(
                query_id=query_id,
                platform=platform,
                status="pending",
            )
            db.add(task)
            await db.commit()
            await db.refresh(task)

        return task

    def _calculate_next_query_at(self, frequency: str | None) -> datetime:
        """Return the next run time from now for ``frequency``.

        "daily" = 1 day, "weekly" = 7 days, "monthly" = 30 days (approximation).
        ``None`` or any unknown value falls back to weekly. Uses naive
        ``datetime.utcnow()`` like the rest of this class.
        """
        now = datetime.utcnow()
        freq_map = {
            "daily": timedelta(days=1),
            "weekly": timedelta(days=7),
            "monthly": timedelta(days=30),
        }
        delta = freq_map.get(frequency or "weekly", timedelta(days=7))
        return now + delta

    async def close(self):
        """No-op; this engine holds no resources of its own."""
        pass
|