import logging
import os
import re
import time
from urllib.parse import quote
from datetime import UTC, datetime
import httpx
from .base import AIEngineAdapter, AIQueryResult, EngineType
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Platform identifiers that resolve to an EngineType member.
# Note that the "tongyi" identifier is mapped to EngineType.QWEN.
_PLATFORM_NAME_MAP: dict[str, EngineType] = {
    "wenxin": EngineType.WENXIN,
    "kimi": EngineType.KIMI,
    "doubao": EngineType.DOUBAO,
    "tongyi": EngineType.QWEN,
    "deepseek": EngineType.DEEPSEEK,
    "chatgpt": EngineType.CHATGPT,
    "perplexity": EngineType.PERPLEXITY,
    "gemini": EngineType.GEMINI,
    "yuanbao": EngineType.YUANBAO,
}

# Platform identifiers that are served from web-search results; none of
# them has an entry in _PLATFORM_NAME_MAP.
_SEARCH_ONLY_PLATFORMS = {
    "qingyan",
    "tiangong",
    "xinghuo",
}
def get_engine_type_for_platform(platform_name: str) -> EngineType | None:
    """Look up the EngineType registered for *platform_name*.

    Args:
        platform_name: Platform identifier, matched exactly against the keys
            of ``_PLATFORM_NAME_MAP``.

    Returns:
        The mapped ``EngineType``, or ``None`` when the identifier has no
        entry in the map.
    """
    try:
        return _PLATFORM_NAME_MAP[platform_name]
    except KeyError:
        return None
def is_search_only_platform(platform_name: str) -> bool:
    """Tell whether *platform_name* is listed in ``_SEARCH_ONLY_PLATFORMS``.

    Args:
        platform_name: Platform identifier, matched exactly.

    Returns:
        ``True`` if the identifier is a member of the search-only set,
        ``False`` otherwise.
    """
    if platform_name in _SEARCH_ONLY_PLATFORMS:
        return True
    return False
async def search_wikipedia(keyword: str, max_chars: int = 2000) -> str:
    """Return a plain-text extract of the top zh.wikipedia.org hit for *keyword*.

    Two MediaWiki API calls are made against the same endpoint: a
    ``list=search`` query (up to 3 hits) to pick the first matching title,
    then a ``prop=extracts`` query requesting that title's plain-text
    extract limited to 15 sentences.

    Args:
        keyword: Search terms sent as ``srsearch``.
        max_chars: Maximum length of the returned text.

    Returns:
        The extract with ``[n]``-style markers removed and whitespace runs
        collapsed, truncated to ``max_chars`` characters. An empty string is
        returned when the search has no hits or no returned page carries a
        non-empty extract.

    Raises:
        httpx.HTTPStatusError: If either call answers with an error status
            (via ``raise_for_status``). Other httpx errors, such as
            timeouts, are not caught here and propagate unchanged.
    """
    api_url = "https://zh.wikipedia.org/w/api.php"
    headers = {
        "User-Agent": "GEO-Citation-Bot/1.0 (contact@example.com)",
    }

    # A single client serves both requests so the second call can reuse the
    # pooled connection instead of building a new pool for the same host.
    async with httpx.AsyncClient(timeout=30) as client:
        search_resp = await client.get(
            api_url,
            headers=headers,
            params={
                "action": "query",
                "list": "search",
                "srsearch": keyword,
                "srlimit": 3,
                "format": "json",
                "origin": "*",
            },
        )
        search_resp.raise_for_status()
        search_results = search_resp.json().get("query", {}).get("search", [])
        if not search_results:
            return ""

        # Only the first hit is followed up.
        title = search_results[0]["title"]
        extract_resp = await client.get(
            api_url,
            headers=headers,
            params={
                "action": "query",
                "prop": "extracts",
                "titles": title,
                "explaintext": True,
                "exsentences": 15,
                "format": "json",
                "origin": "*",
            },
        )
        extract_resp.raise_for_status()
        pages = extract_resp.json().get("query", {}).get("pages", {})

    # Return the first page that actually carries extract text.
    for page in pages.values():
        extract = page.get("extract", "")
        if extract:
            extract = re.sub(r"\[\d+\]", "", extract)
            extract = re.sub(r"\s+", " ", extract).strip()
            return extract[:max_chars]
    return ""
def _strip_html(raw: str) -> str:
    """Convert an HTML fragment to plain text.

    Tags are removed first and character references are decoded afterwards,
    so escaped text such as ``1 &lt; 2 and 3 &gt; 2`` survives intact
    instead of the decoded ``< 2 and 3 >`` being mistaken for a tag.
    Decoding is a single pass, so ``&amp;lt;`` becomes the literal text
    ``&lt;`` rather than ``<``.

    Args:
        raw: HTML markup that may contain tags and character references.

    Returns:
        The text with tags removed, character references decoded, and every
        run of whitespace (including non-breaking spaces) collapsed to one
        space, with leading and trailing whitespace stripped.
    """
    # Imported locally so the block does not depend on a module-level import.
    from html import unescape

    without_tags = re.sub(r"<[^>]+>", "", raw)
    decoded = unescape(without_tags)
    return re.sub(r"\s+", " ", decoded).strip()
# Search DuckDuckGo's HTML endpoint for `query` and return the parsed results.
#
# Each result is rendered as "title\nsnippet"; results are joined with a blank
# line and at most `max_results` of them are returned.
#
# Any exception raised while fetching or parsing (including the RuntimeErrors
# raised below for a non-result page or an empty parse) is logged as a warning
# and the function falls back to search_wikipedia(query, max_chars=2000).
# If that fallback returns an empty string, RuntimeError is raised carrying the
# text of the original DuckDuckGo error. Exceptions raised by search_wikipedia
# itself are not caught here and propagate to the caller.
async def search_duckduckgo(query: str, max_results: int = 5) -> str:
url = f"https://html.duckduckgo.com/html/?q={quote(query)}"
# Browser-like request headers; Accept-Language prefers zh-CN.
headers = {
"User-Agent": (
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
"AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36"
),
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
"Accept-Language": "zh-CN,zh;q=0.9,en-US;q=0.8,en;q=0.7",
}
try:
async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client:
resp = await client.get(url, headers=headers)
resp.raise_for_status()
html = resp.text
# Reject pages that contain none of the result markers checked below.
if "web-result" not in html and "result__snippet" not in html and "result__title" not in html:
raise RuntimeError("DuckDuckGo 返回了非结果页面")
results: list[str] = []
# First pass: one pattern expected to yield (title, snippet) pairs, as the
# loop below unpacks two groups per match.
# NOTE(review): the pattern literal below is empty and split across two
# lines, which is not a valid string literal — the original regex text
# appears to have been lost; restore it from version control.
result_blocks = re.findall(
r'
',
html,
re.DOTALL | re.IGNORECASE,
)
if result_blocks:
for title_raw, snippet_raw in result_blocks[:max_results]:
title = _strip_html(title_raw)
snippet = _strip_html(snippet_raw)
if title or snippet:
results.append(f"{title}\n{snippet}")
# Second pass, used only when the first produced nothing: titles and snippets
# are matched separately and paired by position.
# NOTE(review): both patterns start with `]*` and have no closing tag, so
# they look truncated as well (presumably tag fragments were stripped) — confirm.
if not results:
snippets = re.findall(
r']*class="result__snippet"[^>]*>(.*?)', html, re.DOTALL | re.IGNORECASE
)
titles = re.findall(
r']*class="result__title"[^>]*>.*?]*>(.*?).*?]*>',
html,
re.DOTALL | re.IGNORECASE,
)
for i in range(min(len(titles), len(snippets), max_results)):
title = _strip_html(titles[i])
snippet = _strip_html(snippets[i])
if title or snippet:
results.append(f"{title}\n{snippet}")
if results:
return "\n\n".join(results)
raise RuntimeError("DuckDuckGo 未解析到结果")
# Broad catch is deliberate here: every failure mode above triggers the
# Wikipedia fallback.
except Exception as e:
logger.warning(f"DuckDuckGo 搜索失败: {e}，回退到 Wikipedia")
wiki_text = await search_wikipedia(query, max_chars=2000)
if wiki_text:
return wiki_text
# NOTE(review): raised without `from e`; the original error is only attached
# through implicit exception context.
raise RuntimeError(f"所有搜索源均失败: {e}")
async def fetch_search_content(platform_name: str, keyword: str) -> str:
    """Fetch web-search text for *keyword* on behalf of *platform_name*.

    The platform name is used only in the log line; the search itself is
    delegated to ``search_duckduckgo`` with ``max_results=5``.

    Args:
        platform_name: Platform identifier, included in the log message.
        keyword: Query text passed to the search.

    Returns:
        Whatever ``search_duckduckgo`` returns for the keyword.
    """
    logger.info(f"[{platform_name}] 搜索查询: {keyword}")
    search_text = await search_duckduckgo(query=keyword, max_results=5)
    return search_text
class SearchOnlyAdapter(AIEngineAdapter):
    """Adapter that answers queries from web-search text.

    ``query`` obtains its content from ``fetch_search_content`` and wraps it
    in an ``AIQueryResult`` tagged with ``mode: search_only``.
    """

    def __init__(self, platform_name: str, **kwargs):
        """Remember *platform_name*, then forward ``kwargs`` to the base class."""
        # Assigned before the base initialiser runs, as in the original order.
        self._platform_name = platform_name
        super().__init__(**kwargs)

    def get_engine_type(self) -> EngineType:
        """Return ``EngineType.DEEPSEEK`` regardless of the platform name.

        NOTE(review): presumably a placeholder because the search-only
        platforms have no ``EngineType`` member of their own — confirm.
        """
        return EngineType.DEEPSEEK

    def _get_env_key(self) -> str | None:
        """Return an empty string as the environment-key name."""
        return ""

    async def query(
        self,
        query: str,
        brand_name: str,
        competitor_names: list[str] | None = None,
    ) -> AIQueryResult:
        """Run a web search for *query* and report brand/competitor mentions.

        Args:
            query: Text sent to ``fetch_search_content``.
            brand_name: Brand passed to ``self._detect_citations``.
            competitor_names: Optional competitor names passed to
                ``self._detect_citations``.

        Returns:
            An ``AIQueryResult`` whose ``raw_response`` is the search text,
            whose ``citations`` list is always empty, and whose
            ``response_time_ms`` covers only the search call.
        """
        started = time.perf_counter()
        search_text = await fetch_search_content(self._platform_name, query)
        # Timing stops before citation detection, so it measures the fetch only.
        duration_ms = int((time.perf_counter() - started) * 1000)

        detection = self._detect_citations(search_text, brand_name, competitor_names)
        brand_hit, competitor_hit, brand_context, competitor_contexts = detection

        return AIQueryResult(
            engine_type=self.get_engine_type(),
            query=query,
            raw_response=search_text,
            citations=[],
            has_brand_citation=brand_hit,
            has_competitor_citation=competitor_hit,
            brand_context=brand_context,
            competitor_contexts=competitor_contexts,
            response_time_ms=duration_ms,
            timestamp=datetime.now(UTC),
            metadata={"platform_name": self._platform_name, "mode": "search_only"},
        )
async def query_platform_raw(
    platform_name: str,
    keyword: str,
    brand_name: str = "",
    competitor_names: list[str] | None = None,
) -> str:
    """Query one platform and return its raw text, prefixed with a source tag.

    Search-only platforms are answered through ``fetch_search_content`` and
    tagged ``[data_source: search_engine]``. Every other platform is resolved
    to an adapter and tagged ``[data_source: ai_platform]``.

    Args:
        platform_name: Platform identifier.
        keyword: Query text.
        brand_name: Forwarded to the adapter; unused on the search-only path.
        competitor_names: Forwarded to the adapter; unused on the
            search-only path.

    Returns:
        The source tag line followed by a newline and the response text.

    Raises:
        ValueError: If the platform has no engine type, or no adapter is
            registered under that engine type's value.
    """
    # Imported at call time, ahead of any branching (presumably to avoid an
    # import cycle with batch_query — confirm).
    from .batch_query import _build_adapters

    if is_search_only_platform(platform_name):
        search_text = await fetch_search_content(platform_name, keyword)
        return f"[data_source: search_engine]\n{search_text}"

    resolved_engine = get_engine_type_for_platform(platform_name)
    if resolved_engine is None:
        raise ValueError(f"不支持的平台: {platform_name}")

    # Adapters are keyed by the enum member's value, not the member itself.
    chosen_adapter = _build_adapters().get(resolved_engine.value)
    if chosen_adapter is None:
        raise ValueError(f"平台 {platform_name} 适配器未注册")

    answer = await chosen_adapter.query(keyword, brand_name, competitor_names)
    return f"[data_source: ai_platform]\n{answer.raw_response}"
# Platform identifiers accepted by execute_single_platform(); any other value
# is rejected there with ValueError.
_SUPPORTED_PLATFORMS = {
    "wenxin",
    "kimi",
    "doubao",
    "tongyi",
    "qingyan",
    "tiangong",
    "xinghuo",
}
async def execute_single_platform(
    keyword: str,
    platform: str,
    target_brand: str,
    brand_aliases: list,
) -> dict:
    """Query one platform and analyse the answer for brand citations.

    The query sent to the platform is ``"{keyword} {target_brand}"``. The raw
    answer is passed through ``analyze_citations``; brand matching and
    competitor detection then run on its ``clean_response``.

    Args:
        keyword: Base query text.
        platform: Platform identifier; must be in ``_SUPPORTED_PLATFORMS``.
        target_brand: Brand appended to the query and matched in the answer.
        brand_aliases: Alternative brand names handed to ``BrandMatcher``.

    Returns:
        A dict with the match fields (``cited``, ``confidence``,
        ``match_type``, ``position``, ``citation_text``), the detected
        ``competitor_brands``, the ``raw_response``, the ``data_source``,
        the non-empty ``source_urls`` / ``source_titles`` /
        ``citation_contexts`` of the extracted citations, and the cleaned
        ``ai_response_text``.

    Raises:
        ValueError: If *platform* is not in ``_SUPPORTED_PLATFORMS``.
    """
    if platform not in _SUPPORTED_PLATFORMS:
        raise ValueError(f"不支持的平台: {platform}")

    # Worker modules are imported at call time, in the same order as the
    # steps that need them.
    from app.workers.citation_extractor import analyze_citations

    raw_response = await query_platform_raw(
        platform_name=platform,
        keyword=f"{keyword} {target_brand}",
        brand_name=target_brand,
    )
    analysis = analyze_citations(raw_response)

    from app.workers.citation_engine import BrandMatcher, CompetitorDetector

    cleaned_text = analysis.clean_response
    match_result = BrandMatcher(
        target_brand=target_brand, brand_aliases=brand_aliases
    ).match(cleaned_text)
    competitor_brands = CompetitorDetector().detect(cleaned_text, target_brand)

    def _non_empty(attribute: str) -> list:
        """Collect the truthy values of *attribute* across all citations."""
        values = (getattr(item, attribute) for item in analysis.citations)
        return [value for value in values if value]

    # Match fields come first so the key order of the result is unchanged.
    outcome = {
        field: match_result[field]
        for field in ("cited", "confidence", "match_type", "position", "citation_text")
    }
    outcome["competitor_brands"] = competitor_brands
    outcome["raw_response"] = raw_response
    outcome["data_source"] = analysis.data_source
    outcome["source_urls"] = _non_empty("source_url")
    outcome["source_titles"] = _non_empty("source_title")
    outcome["citation_contexts"] = _non_empty("citation_context")
    outcome["ai_response_text"] = cleaned_text
    return outcome