# Source: geo/backend/app/workers/citation_engine.py
# (file-viewer metadata: 282 lines, 9.3 KiB, Python)

import difflib
import logging
import re
import uuid
from datetime import datetime, timedelta
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.models.citation_record import CitationRecord
from app.models.query import Query
from app.models.query_task import QueryTask
from app.services.ai_engine.platform_bridge import query_platform_raw
from app.workers.citation_extractor import analyze_citations
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class BrandMatcher:
def __init__(self, target_brand: str, brand_aliases: list[str] | None = None):
self.target_brand = target_brand
self.brand_aliases = brand_aliases or []
def match(self, text: str) -> dict:
if not text:
return {
"cited": False,
"confidence": 0.0,
"match_type": None,
"position": None,
"citation_text": None,
}
if self.target_brand in text:
position, citation_text = self._extract_position_and_context(text, self.target_brand)
return {
"cited": True,
"confidence": 1.0,
"match_type": "exact",
"position": position,
"citation_text": citation_text,
}
for alias in self.brand_aliases:
if alias in text:
position, citation_text = self._extract_position_and_context(text, alias)
return {
"cited": True,
"confidence": 0.9,
"match_type": "alias",
"position": position,
"citation_text": citation_text,
}
best_ratio = 0.0
best_match = None
for word in self._extract_candidates(text):
ratio = difflib.SequenceMatcher(None, self.target_brand, word).ratio()
if ratio > best_ratio:
best_ratio = ratio
best_match = word
for alias in self.brand_aliases:
for word in self._extract_candidates(text):
ratio = difflib.SequenceMatcher(None, alias, word).ratio()
if ratio > best_ratio:
best_ratio = ratio
best_match = word
if best_ratio > 0.4 and best_match:
position, citation_text = self._extract_position_and_context(text, best_match)
return {
"cited": True,
"confidence": round(best_ratio, 2),
"match_type": "fuzzy",
"position": position,
"citation_text": citation_text,
}
return {
"cited": False,
"confidence": 0.0,
"match_type": None,
"position": None,
"citation_text": None,
}
def _extract_candidates(self, text: str) -> list[str]:
return [w for w in re.split(r'[^\w\u4e00-\u9fff]+', text) if len(w) >= 2]
def _extract_position_and_context(self, text: str, keyword: str) -> tuple[int | None, str | None]:
paragraphs = [p.strip() for p in text.split('\n') if p.strip()]
if not paragraphs:
paragraphs = [text]
for idx, paragraph in enumerate(paragraphs, start=1):
if keyword in paragraph:
snippet = paragraph[:200]
return idx, snippet
return None, None
class CompetitorDetector:
    """Find known brand names, other than the target brand, inside a text."""

    # Known brand names grouped by industry label.
    KNOWN_BRANDS = {
        "保险": ["中国平安", "中国人寿", "太平洋保险", "新华保险", "泰康保险", "中国人保", "友邦保险"],
        "金融": ["工商银行", "建设银行", "农业银行", "中国银行", "招商银行", "交通银行"],
        "科技": ["华为", "腾讯", "阿里巴巴", "百度", "字节跳动", "小米", "京东"],
    }

    def detect(self, text: str, target_brand: str) -> list[str]:
        """Return the sorted, de-duplicated known brands mentioned in ``text``.

        A brand equal to ``target_brand`` is never reported. Empty text yields
        an empty list. Matching is a plain substring test.
        """
        if not text:
            return []
        mentioned = {
            name
            for names in self.KNOWN_BRANDS.values()
            for name in names
            if name != target_brand and name in text
        }
        return sorted(mentioned)
class CitationEngine:
    """Run a brand-citation query on each platform and persist the results.

    For every platform of a ``Query`` the engine marks a ``QueryTask`` as
    running, asks the platform for a raw answer, analyses it for brand
    mentions and competitor brands, and stores a ``CitationRecord``. A failure
    on one platform is recorded and does not stop the remaining platforms.
    """

    # Re-query interval per frequency label.
    _FREQUENCY_INTERVALS = {
        "daily": timedelta(days=1),
        "weekly": timedelta(days=7),
        "monthly": timedelta(days=30),
    }
    # Used when the frequency label is not one of the keys above.
    _DEFAULT_INTERVAL = timedelta(days=7)

    def __init__(self):
        """Set up the supported-platform set and the competitor detector."""
        self._supported_platforms = {
            "wenxin", "kimi", "doubao", "tongyi",
            "qingyan", "tiangong", "xinghuo",
        }
        self.matcher = None
        self.competitor_detector = CompetitorDetector()

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Produces the same value as the deprecated ``datetime.utcnow()``
        without triggering its deprecation warning on Python 3.12+.
        """
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None)

    async def execute_query(self, query: Query, db: AsyncSession) -> list[CitationRecord]:
        """Execute ``query`` on each of its platforms and store the results.

        Args:
            query: The query to run. Reads ``id``, ``keyword``, ``target_brand``,
                ``brand_aliases``, ``platforms`` and ``frequency``; writes
                ``last_queried_at`` and ``next_query_at``.
            db: Session used to persist tasks and records. It is committed
                several times per platform.

        Returns:
            One ``CitationRecord`` per platform, in platform order. A failed
            platform yields a record with ``cited`` False and the error text
            as ``raw_response``.
        """
        # NOTE(review): this attribute is not read anywhere in this class
        # (execute_single_platform builds its own matcher); it is kept in case
        # external code relies on it.
        self.matcher = BrandMatcher(
            target_brand=query.target_brand,
            brand_aliases=query.brand_aliases or [],
        )
        records: list[CitationRecord] = []
        platforms = query.platforms or ["wenxin", "kimi"]

        for platform_name in platforms:
            task = await self._get_or_create_task(db, query.id, platform_name)
            task.status = "running"
            task.started_at = self._utcnow()
            task.error_message = None
            # Commit first so the "running" state is persisted before the
            # platform call is awaited.
            await db.commit()

            try:
                result = await self.execute_single_platform(
                    keyword=query.keyword,
                    platform=platform_name,
                    target_brand=query.target_brand,
                    brand_aliases=query.brand_aliases or [],
                )
                record = CitationRecord.from_citation_result(
                    query_id=query.id,
                    platform=platform_name,
                    result=result,
                )
                db.add(record)
                records.append(record)
                task.status = "success"
                task.completed_at = self._utcnow()
                await db.commit()
            except Exception as e:
                # Deliberately broad: one failing platform must not abort the
                # others. logger.exception keeps the traceback in the log.
                # NOTE(review): if the failure came from the commit above, the
                # session may need a rollback before it can be reused - confirm.
                logger.exception("平台 %s 查询失败: %s", platform_name, e)
                error_msg = str(e)
                task.status = "failed"
                task.error_message = error_msg
                task.completed_at = self._utcnow()
                record = CitationRecord.from_citation_result(
                    query_id=query.id,
                    platform=platform_name,
                    result={"cited": False, "raw_response": error_msg},
                )
                db.add(record)
                records.append(record)
                await db.commit()

        query.last_queried_at = self._utcnow()
        query.next_query_at = self._calculate_next_query_at(query.frequency)
        await db.commit()
        return records

    async def execute_single_platform(
        self,
        keyword: str,
        platform: str,
        target_brand: str,
        brand_aliases: list,
    ) -> dict:
        """Query one platform and analyse its answer for brand citations.

        Args:
            keyword: Search keyword; the target brand is appended to it.
            platform: Platform name; must be one of the supported platforms.
            target_brand: Brand whose citation is being checked.
            brand_aliases: Alternative names for the brand.

        Returns:
            A dict combining the brand-match report (``cited``, ``confidence``,
            ``match_type``, ``position``, ``citation_text``) with
            ``competitor_brands``, ``raw_response``, ``data_source``,
            ``source_urls``, ``source_titles``, ``citation_contexts`` and
            ``ai_response_text``.

        Raises:
            ValueError: If ``platform`` is not supported.
        """
        if platform not in self._supported_platforms:
            raise ValueError(f"不支持的平台: {platform}")

        search_keyword = f"{keyword} {target_brand}"
        raw_response = await query_platform_raw(
            platform_name=platform,
            keyword=search_keyword,
            brand_name=target_brand,
        )

        citation_analysis = analyze_citations(raw_response)
        clean_text = citation_analysis.clean_response

        matcher = BrandMatcher(target_brand=target_brand, brand_aliases=brand_aliases)
        match_result = matcher.match(clean_text)
        competitor_brands = self.competitor_detector.detect(clean_text, target_brand)

        # Keep only the citations that actually carry each field.
        citations = citation_analysis.citations
        source_urls = [c.source_url for c in citations if c.source_url]
        source_titles = [c.source_title for c in citations if c.source_title]
        citation_contexts = [
            c.citation_context for c in citations if c.citation_context
        ]

        return {
            "cited": match_result["cited"],
            "confidence": match_result["confidence"],
            "match_type": match_result["match_type"],
            "position": match_result["position"],
            "citation_text": match_result["citation_text"],
            "competitor_brands": competitor_brands,
            "raw_response": raw_response,
            "data_source": citation_analysis.data_source,
            "source_urls": source_urls,
            "source_titles": source_titles,
            "citation_contexts": citation_contexts,
            "ai_response_text": clean_text,
        }

    async def _get_or_create_task(
        self, db: AsyncSession, query_id: uuid.UUID, platform: str
    ) -> QueryTask:
        """Fetch the task for ``(query_id, platform)``, creating it if absent.

        A newly created task starts with status ``"pending"`` and is committed
        and refreshed before being returned.
        """
        stmt = select(QueryTask).where(
            QueryTask.query_id == query_id,
            QueryTask.platform == platform,
        )
        result = await db.execute(stmt)
        task = result.scalar_one_or_none()
        if not task:
            task = QueryTask(
                query_id=query_id,
                platform=platform,
                status="pending",
            )
            db.add(task)
            await db.commit()
            await db.refresh(task)
        return task

    def _calculate_next_query_at(self, frequency: str | None) -> datetime:
        """Return the naive-UTC time at which the query should next run.

        ``None`` or an empty frequency is treated as ``"weekly"``; an
        unrecognised label also falls back to seven days.
        """
        delta = self._FREQUENCY_INTERVALS.get(
            frequency or "weekly", self._DEFAULT_INTERVAL
        )
        return self._utcnow() + delta

    async def close(self):
        """No-op; the engine holds no resources to release."""
        pass