296 lines
9.5 KiB
Python
296 lines
9.5 KiB
Python
import logging
|
|
import os
|
|
import re
|
|
import time
|
|
from urllib.parse import quote
|
|
from datetime import UTC, datetime
|
|
|
|
import httpx
|
|
|
|
from .base import AIEngineAdapter, AIQueryResult, EngineType
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Platform identifier -> EngineType member used to look up the adapter.
# Note that the "tongyi" identifier is served by EngineType.QWEN.
_PLATFORM_NAME_MAP: dict[str, EngineType] = {
    "wenxin": EngineType.WENXIN,
    "kimi": EngineType.KIMI,
    "doubao": EngineType.DOUBAO,
    "tongyi": EngineType.QWEN,
    "deepseek": EngineType.DEEPSEEK,
    "chatgpt": EngineType.CHATGPT,
    "perplexity": EngineType.PERPLEXITY,
    "gemini": EngineType.GEMINI,
    "yuanbao": EngineType.YUANBAO,
}

# Platform identifiers that have no entry in _PLATFORM_NAME_MAP and are
# answered from web-search text instead (see is_search_only_platform).
_SEARCH_ONLY_PLATFORMS = {
    "qingyan",
    "tiangong",
    "xinghuo",
}
|
|
|
|
|
|
def get_engine_type_for_platform(platform_name: str) -> EngineType | None:
    """Look up the engine type registered for *platform_name*.

    Args:
        platform_name: Platform identifier; matched exactly (case-sensitive)
            against the keys of ``_PLATFORM_NAME_MAP``.

    Returns:
        The mapped ``EngineType``, or ``None`` when the identifier is unknown.
    """
    try:
        return _PLATFORM_NAME_MAP[platform_name]
    except KeyError:
        return None
|
|
|
|
|
|
def is_search_only_platform(platform_name: str) -> bool:
    """Tell whether *platform_name* is listed in ``_SEARCH_ONLY_PLATFORMS``.

    Args:
        platform_name: Platform identifier; compared exactly (case-sensitive).

    Returns:
        ``True`` if the identifier is a member of the search-only set.
    """
    is_member = platform_name in _SEARCH_ONLY_PLATFORMS
    return is_member
|
|
|
|
|
|
async def search_wikipedia(keyword: str, max_chars: int = 2000) -> str:
    """Return a plain-text extract of the top zh.wikipedia.org hit for *keyword*.

    Two MediaWiki API requests are issued: a full-text search (three hits
    requested, only the first is used) and a ``prop=extracts`` request for
    that page's first 15 sentences as plain text.

    Args:
        keyword: Search terms, sent as the ``srsearch`` parameter.
        max_chars: Maximum length of the returned string.

    Returns:
        The extract with ``[n]`` footnote markers removed and whitespace
        collapsed, truncated to ``max_chars`` characters. An empty string is
        returned when the search has no hits or no page carries an extract.

    Raises:
        httpx.HTTPStatusError: If either response has an error status
            (raised by ``raise_for_status``).
    """
    api_url = "https://zh.wikipedia.org/w/api.php"
    headers = {
        "User-Agent": "GEO-Citation-Bot/1.0 (contact@example.com)",
    }

    # A single client serves both requests so the second call can reuse the
    # pooled connection instead of opening a fresh one.
    async with httpx.AsyncClient(timeout=30) as client:
        search_resp = await client.get(
            api_url,
            headers=headers,
            params={
                "action": "query",
                "list": "search",
                "srsearch": keyword,
                "srlimit": 3,
                "format": "json",
                "origin": "*",
            },
        )
        search_resp.raise_for_status()

        hits = search_resp.json().get("query", {}).get("search", [])
        if not hits:
            return ""

        extract_resp = await client.get(
            api_url,
            headers=headers,
            params={
                "action": "query",
                "prop": "extracts",
                "titles": hits[0]["title"],
                "explaintext": True,
                "exsentences": 15,
                "format": "json",
                "origin": "*",
            },
        )
        extract_resp.raise_for_status()
        pages = extract_resp.json().get("query", {}).get("pages", {})

    # The first page that actually carries text wins.
    for page in pages.values():
        extract = page.get("extract", "")
        if extract:
            extract = re.sub(r"\[\d+\]", "", extract)  # drop footnote markers
            extract = re.sub(r"\s+", " ", extract).strip()
            return extract[:max_chars]

    return ""
|
|
|
|
|
|
def _strip_html(raw: str) -> str:
|
|
raw = raw.replace(" ", " ")
|
|
raw = raw.replace(""", '"')
|
|
raw = raw.replace("&", "&")
|
|
raw = raw.replace("<", "<")
|
|
raw = raw.replace(">", ">")
|
|
raw = raw.replace("'", "'")
|
|
text = re.sub(r"<[^>]+>", "", raw)
|
|
text = re.sub(r"\s+", " ", text).strip()
|
|
return text
|
|
|
|
|
|
def _parse_duckduckgo_results(page: str, max_results: int) -> list[str]:
    """Extract up to *max_results* "title, newline, snippet" entries from *page*."""
    # A negative limit would otherwise slice results off the end of the list.
    limit = max(max_results, 0)
    flags = re.DOTALL | re.IGNORECASE

    def _format(pairs: list[tuple[str, str]]) -> list[str]:
        entries: list[str] = []
        for title_raw, snippet_raw in pairs:
            title = _strip_html(title_raw)
            snippet = _strip_html(snippet_raw)
            if title or snippet:
                entries.append(f"{title}\n{snippet}")
        return entries

    # First attempt: match title and snippet together inside one result <div>.
    blocks = re.findall(
        r'<div class="result[^"]*"[^>]*>.*?<h[^>]*class="result__title"[^>]*>.*?<a[^>]*>(.*?)</a>.*?</h[^>]*>.*?<a[^>]*class="result__snippet"[^>]*>(.*?)</a>.*?</div>',
        page,
        flags,
    )
    entries = _format(blocks[:limit])
    if entries:
        return entries

    # Fallback: collect titles and snippets independently and pair them by
    # position in the page.
    snippets = re.findall(
        r'<a[^>]*class="result__snippet"[^>]*>(.*?)</a>', page, flags
    )
    titles = re.findall(
        r'<h[^>]*class="result__title"[^>]*>.*?<a[^>]*>(.*?)</a>.*?</h[^>]*>',
        page,
        flags,
    )
    return _format(list(zip(titles, snippets))[:limit])


async def search_duckduckgo(query: str, max_results: int = 5) -> str:
    """Search DuckDuckGo's HTML endpoint, falling back to Wikipedia on failure.

    Args:
        query: Search terms; URL-quoted into the ``q`` parameter.
        max_results: Maximum number of results to include. Values below
            zero are treated as zero.

    Returns:
        Results joined by blank lines, each formatted as the title on one
        line and the snippet on the next. If DuckDuckGo fails for any
        reason, the text returned by ``search_wikipedia(query)`` instead.

    Raises:
        RuntimeError: If DuckDuckGo failed and Wikipedia returned no text;
            the DuckDuckGo failure is attached as ``__cause__``.
        Exception: Whatever ``search_wikipedia`` raises during the fallback
            propagates unchanged.
    """
    url = f"https://html.duckduckgo.com/html/?q={quote(query)}"
    headers = {
        "User-Agent": (
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
            "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36"
        ),
        "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
        "Accept-Language": "zh-CN,zh;q=0.9,en-US;q=0.8,en;q=0.7",
    }

    try:
        async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client:
            resp = await client.get(url, headers=headers)
            resp.raise_for_status()
            page = resp.text

        # None of the result markers present: this is not a results page.
        markers = ("web-result", "result__snippet", "result__title")
        if not any(marker in page for marker in markers):
            raise RuntimeError("DuckDuckGo 返回了非结果页面")

        results = _parse_duckduckgo_results(page, max_results)
        if results:
            return "\n\n".join(results)

        raise RuntimeError("DuckDuckGo 未解析到结果")

    except Exception as e:
        # Deliberately broad: any DuckDuckGo failure triggers the fallback.
        logger.warning("DuckDuckGo 搜索失败: %s,回退到 Wikipedia", e)
        wiki_text = await search_wikipedia(query, max_chars=2000)
        if wiki_text:
            return wiki_text
        raise RuntimeError(f"所有搜索源均失败: {e}") from e
|
|
|
|
|
|
async def fetch_search_content(platform_name: str, keyword: str) -> str:
    """Log the query for *platform_name* and return search text for *keyword*.

    Args:
        platform_name: Used only as a prefix in the log line.
        keyword: Search terms forwarded to ``search_duckduckgo``.

    Returns:
        The string returned by ``search_duckduckgo(keyword, max_results=5)``.
    """
    log_line = f"[{platform_name}] 搜索查询: {keyword}"
    logger.info(log_line)
    search_text = await search_duckduckgo(keyword, max_results=5)
    return search_text
|
|
|
|
|
|
class SearchOnlyAdapter(AIEngineAdapter):
    """Adapter whose response text comes from ``fetch_search_content``.

    The platform name given at construction is used for the search log line
    and recorded in the result metadata.
    """

    def __init__(self, platform_name: str, **kwargs):
        """Remember *platform_name*, then forward ``kwargs`` to the base class."""
        # Assigned before the base initialiser runs, as in the original order.
        self._platform_name = platform_name
        super().__init__(**kwargs)

    def get_engine_type(self) -> EngineType:
        """Return the engine type reported in results from this adapter."""
        # NOTE(review): every search-only platform reports DEEPSEEK here;
        # looks like a placeholder -- confirm this is intended.
        return EngineType.DEEPSEEK

    def _get_env_key(self) -> str | None:
        """Return an empty string as the environment key."""
        # NOTE(review): presumably signals that no API key is required --
        # verify against how AIEngineAdapter uses this value.
        return ""

    async def query(
        self,
        query: str,
        brand_name: str,
        competitor_names: list[str] | None = None,
    ) -> AIQueryResult:
        """Fetch search text for *query* and run citation detection on it.

        Args:
            query: Search terms passed to ``fetch_search_content``.
            brand_name: Brand passed to ``_detect_citations``.
            competitor_names: Competitor names passed to ``_detect_citations``.

        Returns:
            An ``AIQueryResult`` with the search text as ``raw_response``, an
            empty ``citations`` list, and ``response_time_ms`` covering only
            the fetch (detection is not timed).
        """
        began = time.perf_counter()
        text = await fetch_search_content(self._platform_name, query)
        duration_ms = int((time.perf_counter() - began) * 1000)

        detection = self._detect_citations(text, brand_name, competitor_names)
        has_brand, has_comp, brand_ctx, comp_ctx = detection

        result_fields = {
            "engine_type": self.get_engine_type(),
            "query": query,
            "raw_response": text,
            "citations": [],
            "has_brand_citation": has_brand,
            "has_competitor_citation": has_comp,
            "brand_context": brand_ctx,
            "competitor_contexts": comp_ctx,
            "response_time_ms": duration_ms,
            "timestamp": datetime.now(UTC),
            "metadata": {"platform_name": self._platform_name, "mode": "search_only"},
        }
        return AIQueryResult(**result_fields)
|
|
|
|
|
|
async def query_platform_raw(
    platform_name: str,
    keyword: str,
    brand_name: str = "",
    competitor_names: list[str] | None = None,
) -> str:
    """Query one platform and return its text prefixed with a data-source tag.

    Args:
        platform_name: Platform identifier.
        keyword: Query text.
        brand_name: Forwarded to the adapter; unused for search-only platforms.
        competitor_names: Forwarded to the adapter; unused for search-only
            platforms.

    Returns:
        ``"[data_source: search_engine]"`` or ``"[data_source: ai_platform]"``
        on the first line, followed by the response text.

    Raises:
        ValueError: If the platform has no engine type, or no adapter is
            registered under that engine type's value.
    """
    # Imported inside the function, as in the original.
    from .batch_query import _build_adapters

    if is_search_only_platform(platform_name):
        content = await fetch_search_content(platform_name, keyword)
        return f"[data_source: search_engine]\n{content}"

    if (engine_type := get_engine_type_for_platform(platform_name)) is None:
        raise ValueError(f"不支持的平台: {platform_name}")

    if (adapter := _build_adapters().get(engine_type.value)) is None:
        raise ValueError(f"平台 {platform_name} 适配器未注册")

    result = await adapter.query(keyword, brand_name, competitor_names)
    return f"[data_source: ai_platform]\n{result.raw_response}"
|
|
|
|
|
|
_SUPPORTED_PLATFORMS = {
|
|
"wenxin", "kimi", "doubao", "tongyi",
|
|
"qingyan", "tiangong", "xinghuo",
|
|
}
|
|
|
|
|
|
async def execute_single_platform(
    keyword: str,
    platform: str,
    target_brand: str,
    brand_aliases: list,
) -> dict:
    """Query one platform and report whether *target_brand* is cited.

    The platform is queried with ``"{keyword} {target_brand}"``. The raw
    response goes through ``analyze_citations``; brand matching and competitor
    detection then run on the cleaned response text.

    Args:
        keyword: Base query text.
        platform: Platform identifier; must be in ``_SUPPORTED_PLATFORMS``.
        target_brand: Brand appended to the query and searched for in the
            response.
        brand_aliases: Aliases handed to ``BrandMatcher``.

    Returns:
        A dict with the match fields (``cited``, ``confidence``,
        ``match_type``, ``position``, ``citation_text``), followed by
        ``competitor_brands``, ``raw_response``, ``data_source``,
        ``source_urls``, ``source_titles``, ``citation_contexts`` and
        ``ai_response_text``.

    Raises:
        ValueError: If *platform* is not in ``_SUPPORTED_PLATFORMS``.
    """
    if platform not in _SUPPORTED_PLATFORMS:
        raise ValueError(f"不支持的平台: {platform}")

    # Imported inside the function, as in the original.
    from app.workers.citation_extractor import analyze_citations

    raw_response = await query_platform_raw(
        platform_name=platform,
        keyword=f"{keyword} {target_brand}",
        brand_name=target_brand,
    )
    analysis = analyze_citations(raw_response)

    from app.workers.citation_engine import BrandMatcher, CompetitorDetector

    clean_text = analysis.clean_response
    match_result = BrandMatcher(
        target_brand=target_brand, brand_aliases=brand_aliases
    ).match(clean_text)
    competitor_brands = CompetitorDetector().detect(clean_text, target_brand)

    # Keep only the truthy values of each citation attribute.
    citations = analysis.citations
    source_urls = [url for c in citations if (url := c.source_url)]
    source_titles = [label for c in citations if (label := c.source_title)]
    citation_contexts = [ctx for c in citations if (ctx := c.citation_context)]

    match_fields = ("cited", "confidence", "match_type", "position", "citation_text")
    outcome = {field: match_result[field] for field in match_fields}
    outcome.update(
        competitor_brands=competitor_brands,
        raw_response=raw_response,
        data_source=analysis.data_source,
        source_urls=source_urls,
        source_titles=source_titles,
        citation_contexts=citation_contexts,
        ai_response_text=clean_text,
    )
    return outcome
|