geo/backend/app/services/ai_engine/platform_bridge.py

296 lines
9.5 KiB
Python

import logging
import os
import re
import time
from urllib.parse import quote
from datetime import UTC, datetime
import httpx
from .base import AIEngineAdapter, AIQueryResult, EngineType
logger = logging.getLogger(__name__)
# Maps the platform names accepted by this module to the EngineType member
# whose value is used to pick an adapter in query_platform_raw.
# Note that the "tongyi" platform name maps to EngineType.QWEN.
_PLATFORM_NAME_MAP: dict[str, EngineType] = {
"wenxin": EngineType.WENXIN,
"kimi": EngineType.KIMI,
"doubao": EngineType.DOUBAO,
"tongyi": EngineType.QWEN,
"deepseek": EngineType.DEEPSEEK,
"chatgpt": EngineType.CHATGPT,
"perplexity": EngineType.PERPLEXITY,
"gemini": EngineType.GEMINI,
"yuanbao": EngineType.YUANBAO,
}
# Platform names that are answered from web search results rather than through
# an EngineType adapter; none of them has an entry in _PLATFORM_NAME_MAP.
_SEARCH_ONLY_PLATFORMS = {"qingyan", "tiangong", "xinghuo"}
def get_engine_type_for_platform(platform_name: str) -> EngineType | None:
    """Look up the engine type registered for a platform name.

    Args:
        platform_name: Key to look up in ``_PLATFORM_NAME_MAP``.

    Returns:
        The mapped ``EngineType``, or ``None`` when the name has no entry.
    """
    if platform_name in _PLATFORM_NAME_MAP:
        return _PLATFORM_NAME_MAP[platform_name]
    return None
def is_search_only_platform(platform_name: str) -> bool:
    """Tell whether a platform is served from web search results.

    Args:
        platform_name: Platform name to test.

    Returns:
        ``True`` when the name is a member of ``_SEARCH_ONLY_PLATFORMS``.
    """
    search_only = platform_name in _SEARCH_ONLY_PLATFORMS
    return search_only
async def search_wikipedia(keyword: str, max_chars: int = 2000) -> str:
    """Return a plain-text extract of the top zh.wikipedia.org hit for ``keyword``.

    Two MediaWiki API calls are made over a single HTTP client: a full-text
    search for ``keyword``, then an ``extracts`` query for the title of the
    first search hit.

    Args:
        keyword: Search phrase. A blank or whitespace-only value returns ``""``
            without touching the network.
        max_chars: Upper bound on the length of the returned text. Values
            below zero are treated as zero.

    Returns:
        The cleaned extract (bracketed numeric markers removed, whitespace
        collapsed), truncated to ``max_chars``; ``""`` when the search has no
        hits or no page carries an extract.

    Raises:
        httpx.HTTPStatusError: If either API call returns an error status.
        httpx.HTTPError: On transport-level failures.
    """
    # A blank query cannot match a page; skip both round-trips.
    if not keyword or not keyword.strip():
        return ""

    api_url = "https://zh.wikipedia.org/w/api.php"
    headers = {
        "User-Agent": "GEO-Citation-Bot/1.0 (contact@example.com)",
    }

    # One client serves both requests so the connection to the host is reused.
    async with httpx.AsyncClient(timeout=30) as client:
        search_resp = await client.get(
            api_url,
            headers=headers,
            params={
                "action": "query",
                "list": "search",
                "srsearch": keyword,
                "srlimit": 3,
                "format": "json",
                "origin": "*",
            },
        )
        search_resp.raise_for_status()
        hits = search_resp.json().get("query", {}).get("search", [])
        if not hits:
            return ""

        extract_resp = await client.get(
            api_url,
            headers=headers,
            params={
                "action": "query",
                "prop": "extracts",
                "titles": hits[0]["title"],
                "explaintext": True,
                "exsentences": 15,
                "format": "json",
                "origin": "*",
            },
        )
        extract_resp.raise_for_status()
        pages = extract_resp.json().get("query", {}).get("pages", {})

    # Clamp so a negative limit yields "" instead of trimming from the end.
    limit = max(max_chars, 0)
    for page in pages.values():
        extract = page.get("extract", "")
        if extract:
            extract = re.sub(r"\[\d+\]", "", extract)  # drop "[12]"-style markers
            extract = re.sub(r"\s+", " ", extract).strip()
            return extract[:limit]
    return ""
def _strip_html(raw: str) -> str:
raw = raw.replace(" ", " ")
raw = raw.replace(""", '"')
raw = raw.replace("&", "&")
raw = raw.replace("&lt;", "<")
raw = raw.replace("&gt;", ">")
raw = raw.replace("&#39;", "'")
text = re.sub(r"<[^>]+>", "", raw)
text = re.sub(r"\s+", " ", text).strip()
return text
def _parse_duckduckgo_results(page_html: str, max_results: int) -> list[str]:
    """Extract up to ``max_results`` title-plus-snippet entries from a result page."""
    limit = max(max_results, 0)  # a negative limit means "no results", not "all but N"
    flags = re.DOTALL | re.IGNORECASE
    entries: list[str] = []

    # Preferred: match whole result containers so each title is paired with
    # the snippet from the same container.
    result_blocks = re.findall(
        r'<div class="result[^"]*"[^>]*>.*?<h[^>]*class="result__title"[^>]*>.*?<a[^>]*>(.*?)</a>.*?</h[^>]*>.*?<a[^>]*class="result__snippet"[^>]*>(.*?)</a>.*?</div>',
        page_html,
        flags,
    )
    for title_raw, snippet_raw in result_blocks[:limit]:
        title = _strip_html(title_raw)
        snippet = _strip_html(snippet_raw)
        if title or snippet:
            entries.append(f"{title}\n{snippet}")
    if entries:
        return entries

    # Fallback: collect titles and snippets independently and pair by position.
    snippets = re.findall(
        r'<a[^>]*class="result__snippet"[^>]*>(.*?)</a>', page_html, flags
    )
    titles = re.findall(
        r'<h[^>]*class="result__title"[^>]*>.*?<a[^>]*>(.*?)</a>.*?</h[^>]*>',
        page_html,
        flags,
    )
    for title_raw, snippet_raw in zip(titles[:limit], snippets[:limit]):
        title = _strip_html(title_raw)
        snippet = _strip_html(snippet_raw)
        if title or snippet:
            entries.append(f"{title}\n{snippet}")
    return entries


async def search_duckduckgo(query: str, max_results: int = 5) -> str:
    """Search DuckDuckGo's HTML endpoint, falling back to Wikipedia on failure.

    Args:
        query: Search phrase; it is URL-quoted into the request.
        max_results: Maximum number of results to include in the output.

    Returns:
        Results joined by blank lines, each formatted as the title, a newline
        and the snippet. If the DuckDuckGo request or parsing fails, the text
        returned by ``search_wikipedia(query, max_chars=2000)`` instead.

    Raises:
        RuntimeError: When DuckDuckGo fails and Wikipedia returns no text; the
            DuckDuckGo failure is attached as ``__cause__``.
        Exception: Any error raised by the Wikipedia fallback itself
            propagates unchanged.
    """
    url = f"https://html.duckduckgo.com/html/?q={quote(query)}"
    headers = {
        "User-Agent": (
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
            "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36"
        ),
        "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
        "Accept-Language": "zh-CN,zh;q=0.9,en-US;q=0.8,en;q=0.7",
    }
    try:
        async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client:
            resp = await client.get(url, headers=headers)
            resp.raise_for_status()
            page_html = resp.text

        # None of the result markers present: treat the page as a non-result page.
        markers = ("web-result", "result__snippet", "result__title")
        if not any(marker in page_html for marker in markers):
            raise RuntimeError("DuckDuckGo 返回了非结果页面")

        results = _parse_duckduckgo_results(page_html, max_results)
        if results:
            return "\n\n".join(results)
        raise RuntimeError("DuckDuckGo 未解析到结果")
    except Exception as e:  # deliberately broad: any failure triggers the fallback
        logger.warning("DuckDuckGo 搜索失败: %s,回退到 Wikipedia", e)
        wiki_text = await search_wikipedia(query, max_chars=2000)
        if wiki_text:
            return wiki_text
        # Chain the original failure so the root cause stays in the traceback.
        raise RuntimeError(f"所有搜索源均失败: {e}") from e
async def fetch_search_content(platform_name: str, keyword: str) -> str:
    """Fetch search text for ``keyword`` on behalf of a platform.

    Args:
        platform_name: Used only as a tag in the log line.
        keyword: Query passed to ``search_duckduckgo``.

    Returns:
        Whatever ``search_duckduckgo`` returns for ``keyword`` with at most
        five results.
    """
    logger.info(f"[{platform_name}] 搜索查询: {keyword}")
    content = await search_duckduckgo(keyword, max_results=5)
    return content
class SearchOnlyAdapter(AIEngineAdapter):
    """Adapter that answers queries from web search text instead of an AI API."""

    def __init__(self, platform_name: str, **kwargs):
        """Remember the platform name, then run the base initializer.

        Args:
            platform_name: Name reported in logs and result metadata.
            **kwargs: Forwarded unchanged to ``AIEngineAdapter.__init__``.
        """
        # Set first so the name is already available while the base
        # initializer runs.
        self._platform_name = platform_name
        super().__init__(**kwargs)

    def get_engine_type(self) -> EngineType:
        """Return the engine type reported in results."""
        # NOTE(review): always DEEPSEEK, whatever the platform name is —
        # looks like a placeholder; confirm this is intended.
        return EngineType.DEEPSEEK

    def _get_env_key(self) -> str | None:
        """Return an empty string as the environment key."""
        # NOTE(review): presumably no API key is needed for search — confirm.
        return ""

    async def query(
        self,
        query: str,
        brand_name: str,
        competitor_names: list[str] | None = None,
    ) -> AIQueryResult:
        """Run a search for ``query`` and package the text as an ``AIQueryResult``.

        Args:
            query: Text sent to ``fetch_search_content``.
            brand_name: Brand passed to ``_detect_citations``.
            competitor_names: Competitors passed to ``_detect_citations``.

        Returns:
            A result whose ``raw_response`` is the fetched text, whose
            ``citations`` list is empty, and whose metadata records the
            platform name and the ``"search_only"`` mode.
        """
        started = time.perf_counter()
        text = await fetch_search_content(self._platform_name, query)
        # Only the fetch is timed; citation detection is excluded.
        duration_ms = int((time.perf_counter() - started) * 1000)

        detection = self._detect_citations(text, brand_name, competitor_names)
        brand_hit, competitor_hit, brand_context, competitor_contexts = detection

        fields = dict(
            engine_type=self.get_engine_type(),
            query=query,
            raw_response=text,
            citations=[],
            has_brand_citation=brand_hit,
            has_competitor_citation=competitor_hit,
            brand_context=brand_context,
            competitor_contexts=competitor_contexts,
            response_time_ms=duration_ms,
            timestamp=datetime.now(UTC),
            metadata={"platform_name": self._platform_name, "mode": "search_only"},
        )
        return AIQueryResult(**fields)
async def query_platform_raw(
    platform_name: str,
    keyword: str,
    brand_name: str = "",
    competitor_names: list[str] | None = None,
) -> str:
    """Query one platform and return its raw text, prefixed with a source tag.

    Args:
        platform_name: Platform to query.
        keyword: Query text.
        brand_name: Passed to the adapter's ``query``; unused for search-only
            platforms.
        competitor_names: Passed to the adapter's ``query``; unused for
            search-only platforms.

    Returns:
        ``"[data_source: search_engine]"`` or ``"[data_source: ai_platform]"``
        on the first line, followed by the fetched text.

    Raises:
        ValueError: If the platform has no engine-type mapping, or no adapter
            is registered under that engine type's value.
    """
    # Imported on every call, before any branching.
    from .batch_query import _build_adapters

    if is_search_only_platform(platform_name):
        searched = await fetch_search_content(platform_name, keyword)
        return f"[data_source: search_engine]\n{searched}"

    engine = get_engine_type_for_platform(platform_name)
    if engine is None:
        raise ValueError(f"不支持的平台: {platform_name}")

    chosen = _build_adapters().get(engine.value)
    if chosen is None:
        raise ValueError(f"平台 {platform_name} 适配器未注册")

    answer = await chosen.query(keyword, brand_name, competitor_names)
    return f"[data_source: ai_platform]\n{answer.raw_response}"
# Platform names accepted by execute_single_platform; any other name makes it
# raise ValueError.
# NOTE(review): deepseek, chatgpt, perplexity, gemini and yuanbao have entries
# in _PLATFORM_NAME_MAP but are not listed here — confirm the omission is
# intentional.
_SUPPORTED_PLATFORMS = {
"wenxin", "kimi", "doubao", "tongyi",
"qingyan", "tiangong", "xinghuo",
}
async def execute_single_platform(
    keyword: str,
    platform: str,
    target_brand: str,
    brand_aliases: list,
) -> dict:
    """Query one platform for a keyword and analyse brand citations in the reply.

    The platform is queried with ``"<keyword> <target_brand>"``; the raw reply
    goes through ``analyze_citations``, and the cleaned text is then checked
    with ``BrandMatcher`` and ``CompetitorDetector``.

    Args:
        keyword: Base query text.
        platform: Platform name; must be in ``_SUPPORTED_PLATFORMS``.
        target_brand: Brand appended to the query and matched in the reply.
        brand_aliases: Alias list handed to ``BrandMatcher``.

    Returns:
        A dict with the match fields (``cited``, ``confidence``,
        ``match_type``, ``position``, ``citation_text``), the detected
        ``competitor_brands``, the ``raw_response``, the analysis
        ``data_source``, the non-empty ``source_urls`` / ``source_titles`` /
        ``citation_contexts`` of the citations, and the cleaned text as
        ``ai_response_text``.

    Raises:
        ValueError: If ``platform`` is not in ``_SUPPORTED_PLATFORMS``.
    """
    if platform not in _SUPPORTED_PLATFORMS:
        raise ValueError(f"不支持的平台: {platform}")

    from app.workers.citation_extractor import analyze_citations

    raw_response = await query_platform_raw(
        platform_name=platform,
        keyword=f"{keyword} {target_brand}",
        brand_name=target_brand,
    )
    analysis = analyze_citations(raw_response)

    from app.workers.citation_engine import BrandMatcher, CompetitorDetector

    brand_match = BrandMatcher(
        target_brand=target_brand, brand_aliases=brand_aliases
    ).match(analysis.clean_response)
    rivals = CompetitorDetector().detect(analysis.clean_response, target_brand)

    # Gather each citation attribute into its own list, skipping falsy values.
    source_urls: list = []
    source_titles: list = []
    citation_contexts: list = []
    for attr, bucket in (
        ("source_url", source_urls),
        ("source_title", source_titles),
        ("citation_context", citation_contexts),
    ):
        bucket.extend(
            getattr(item, attr) for item in analysis.citations if getattr(item, attr)
        )

    match_fields = ("cited", "confidence", "match_type", "position", "citation_text")
    outcome = {field: brand_match[field] for field in match_fields}
    outcome.update(
        competitor_brands=rivals,
        raw_response=raw_response,
        data_source=analysis.data_source,
        source_urls=source_urls,
        source_titles=source_titles,
        citation_contexts=citation_contexts,
        ai_response_text=analysis.clean_response,
    )
    return outcome