425 lines
16 KiB
Python
425 lines
16 KiB
Python
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import logging
|
||
import re
|
||
from dataclasses import dataclass, field
|
||
from datetime import UTC, datetime
|
||
|
||
import httpx
|
||
from sqlalchemy import func, select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.models.citation_record import CitationRecord
|
||
from app.models.query import Query
|
||
from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput
|
||
|
||
logger = logging.getLogger(__name__)

# AI platforms that DataCollectorService queries for brand answers.
_DEFAULT_PLATFORMS = ["deepseek", "kimi"]

# Query templates filled via str.format(brand=..., industry=...); English glosses:
#   "{brand}是什么"     -> "What is {brand}?"
#   "{brand}怎么样"     -> "How good is {brand}?"
#   "推荐{industry}品牌" -> "Recommend {industry} brands"
_QUERY_KEYWORDS = ["{brand}是什么", "{brand}怎么样", "推荐{industry}品牌"]
|
||
|
||
|
||
@dataclass
class DataCollectionResult:
    """Outcome of one ``DataCollectorService.collect`` run.

    Bundles the merged diagnosis input with run metadata and the error
    strings reported by channels that failed.
    """

    # Signals merged from every channel that produced data.
    diagnosis_input: GEODiagnosisInput
    # Run metadata; collect() stores brand name, timestamp and per-channel details.
    metadata: dict = field(default_factory=dict)
    # One "<channel>: <message>" string per failed channel (empty on full success).
    errors: list[str] = field(default_factory=list)
|
||
|
||
|
||
class DataCollectorService:
    """Gather GEO diagnosis signals for a brand from independent channels.

    :meth:`collect` runs three channels as concurrent asyncio tasks:

    * ``ai_platform`` -- live queries to the platforms in ``_DEFAULT_PLATFORMS``
      through ``execute_single_platform``;
    * ``citation_record`` -- the most recent stored ``CitationRecord`` rows for
      the brand;
    * ``website`` -- one fetch of the brand's website, scanned with regex
      heuristics.

    Each channel yields a plain ``dict`` of signals that is merged into a single
    ``GEODiagnosisInput``. A channel that raises is reported in
    ``DataCollectionResult.errors`` instead of aborting the whole run.
    """

    # Only the first N query templates are sent to each platform, bounding the
    # number of live AI calls to len(_DEFAULT_PLATFORMS) * N.
    _KEYWORDS_PER_PLATFORM = 2

    # Maximum number of stored citation records considered per brand.
    _CITATION_RECORD_LIMIT = 100

    # Response signals shared by the AI and citation channels, merged by
    # keeping the larger value: (input attribute, data key, default).
    _RESPONSE_MAX_FIELDS = (
        ("answer_ownership_rate", "aor", 0.0),
        ("citation_accuracy", "accuracy", 0.0),
        ("ai_sov", "sov", 0.0),
        ("competitor_gap", "competitor_gap", 0.0),
        ("total_ai_responses", "total_responses", 0),
        ("brand_mention_count", "cited_count", 0),
        ("accurate_citation_count", "accurate_count", 0),
    )

    # Citation-channel-only numeric signals, merged the same way.
    _CITATION_MAX_FIELDS = (
        ("certification_count", "certification_count", 0),
        ("endorsement_count", "endorsement_count", 0),
        ("content_depth_score", "content_depth_score", 0.0),
        ("topic_coverage_ratio", "topic_coverage_ratio", 0.0),
        ("entity_consistency_score", "entity_consistency_score", 0.0),
        ("cluster_completeness", "cluster_completeness", 0.0),
        ("total_content_count", "total_content_count", 0),
        ("topic_cluster_count", "topic_cluster_count", 0),
    )

    # Website flags; each is copied onto the diagnosis input only when True.
    _WEBSITE_FLAG_FIELDS = (
        "has_direct_answer",
        "has_qa_headings",
        "has_structured_data",
        "has_internal_links",
        "has_freshness_info",
        "has_brand_definition",
        "has_target_audience",
        "has_unique_value",
        "has_organization",
        "has_product",
        "has_article",
        "has_faq",
        "has_howto",
        "has_breadcrumb",
    )

    # JSON-LD schema flags: (signal key, type names looked for as quoted tokens).
    _SCHEMA_TYPE_TOKENS = (
        ("has_organization", ("Organization", "organization")),
        ("has_product", ("Product", "product")),
        ("has_article", ("Article", "BlogPosting", "article")),
        ("has_faq", ("FAQPage", "faq")),
        ("has_howto", ("HowTo", "howto")),
        ("has_breadcrumb", ("BreadcrumbList", "breadcrumb")),
    )

    _LD_JSON_RE = re.compile(
        r"<script\b[^>]*application/ld\+json[^>]*>(.*?)</script\s*>",
        re.DOTALL | re.IGNORECASE,
    )
    # Script/style bodies and comments are not visible page content.
    _NON_CONTENT_RE = re.compile(
        r"<script\b[^>]*>.*?</script\s*>|<style\b[^>]*>.*?</style\s*>|<!--.*?-->",
        re.DOTALL | re.IGNORECASE,
    )
    _TAG_RE = re.compile(r"<[^>]+>")
    _WHITESPACE_RE = re.compile(r"\s+")
    _HEADING_RE = re.compile(r"<h[23][^>]*>(.*?)</h[23]>", re.DOTALL | re.IGNORECASE)
    # Question markers: full-width/ASCII "?" plus common Chinese question words.
    _QA_RE = re.compile(r"[？?]|如何|什么|为什么|怎么|哪|多少|是否|可以")
    _LIST_OR_TABLE_RE = re.compile(r"<(?:ul|ol|table)\b", re.IGNORECASE)
    # Root-relative ("/x", but not protocol-relative "//host") or "./x" links.
    _INTERNAL_LINK_RE = re.compile(r"""href\s*=\s*["'](?:/(?!/)|\./)""", re.IGNORECASE)
    _FRESHNESS_RE = re.compile(
        r"(20\d{2}[-/年]\d{1,2}[-/月]\d{1,2}[日]?)"
        r"|(更新于|发布于|最后更新|published|updated|modified)"
    )
    _URL_SCHEME_RE = re.compile(r"^[a-zA-Z][a-zA-Z0-9+.\-]*://")
    _BRAND_DEFINITION_KEYWORDS = ("是", "提供", "专注于", "致力于", "is a", "provides", "offers")
    _AUDIENCE_PATTERNS = tuple(
        re.compile(p)
        for p in ("为.*提供", "服务.*用户", "帮助.*企业", "面向", "for ", "serves ", "helps ")
    )
    _UNIQUE_VALUE_KEYWORDS = (
        "优势", "特色", "不同", "独特", "领先", "首创", "唯一",
        "advantage", "unique", "leading", "first",
    )

    def __init__(self, db: AsyncSession):
        """Keep the async DB session used by the citation-record channel."""
        self._db = db

    async def collect(
        self,
        brand_name: str,
        brand_aliases: list[str] | None = None,
        website: str | None = None,
        industry: str | None = None,
    ) -> DataCollectionResult:
        """Collect all channels concurrently and merge them into one diagnosis input.

        Args:
            brand_name: Brand to diagnose.
            brand_aliases: Alternative brand names, forwarded to the AI-platform
                and citation-record channels.
            website: Brand website URL; the website channel is skipped when empty.
            industry: Industry label. It fills the ``{industry}`` query template
                and, when given, sets ``has_industry_classification``.

        Returns:
            A DataCollectionResult holding the merged GEODiagnosisInput,
            per-channel metadata under ``metadata["channels"]`` and one error
            string per channel that raised.
        """
        aliases = brand_aliases or []
        metadata: dict = {
            "brand_name": brand_name,
            "collected_at": datetime.now(UTC).isoformat(),
            "channels": {},
        }

        # The channels are independent, so start them all before awaiting any.
        tasks: dict[str, asyncio.Task] = {
            "ai_platform": asyncio.create_task(
                self._collect_ai_platform_signals(brand_name, aliases, industry)
            ),
            "citation_record": asyncio.create_task(
                self._collect_citation_record_signals(brand_name, aliases)
            ),
            "website": asyncio.create_task(self._collect_website_signals(website)),
        }

        results: dict[str, dict | None] = {}
        errors: list[str] = []
        try:
            for channel, task in tasks.items():
                result, error = await self._safe_await(task, channel)
                results[channel] = result
                if error:
                    errors.append(error)
                metadata["channels"][channel] = (
                    result.get("metadata", {}) if result is not None else {"error": error}
                )
        finally:
            # If collect() is cancelled mid-way, do not leave the remaining
            # channels running in the background; one of them uses self._db.
            for task in tasks.values():
                if not task.done():
                    task.cancel()

        diagnosis_input = GEODiagnosisInput()
        if results["ai_platform"]:
            self._apply_ai_signals(diagnosis_input, results["ai_platform"])
        if results["citation_record"]:
            self._apply_citation_signals(diagnosis_input, results["citation_record"])
        if results["website"]:
            self._apply_website_signals(diagnosis_input, results["website"])
        if industry:
            diagnosis_input.has_industry_classification = True

        return DataCollectionResult(
            diagnosis_input=diagnosis_input,
            metadata=metadata,
            errors=errors,
        )

    async def _collect_ai_platform_signals(
        self,
        brand_name: str,
        brand_aliases: list[str],
        industry: str | None,
    ) -> dict:
        """Query every platform in ``_DEFAULT_PLATFORMS`` and summarise the answers.

        Queries run one after another. Each platform is asked the first
        ``_KEYWORDS_PER_PLATFORM`` templates of ``_QUERY_KEYWORDS``. A failing
        query is logged and skipped. The number of failures is reported as
        ``metadata["failed_queries"]``, so a channel where every query failed
        can be told apart from one that simply found no citations.

        NOTE(review): with ``_KEYWORDS_PER_PLATFORM == 2`` the third
        ``{industry}`` template is never sent, so *industry* has no effect here.
        """
        # Imported lazily; presumably to avoid an import cycle or a heavy
        # import at module load -- NOTE(review): confirm.
        from app.services.ai_engine.platform_bridge import execute_single_platform

        # "科技" ("technology") is the fallback for the {industry} placeholder.
        keywords = [
            template.format(brand=brand_name, industry=industry or "科技")
            for template in _QUERY_KEYWORDS[: self._KEYWORDS_PER_PLATFORM]
        ]

        results: list[dict] = []
        failed_queries = 0
        for platform in _DEFAULT_PLATFORMS:
            for keyword in keywords:
                try:
                    result = await execute_single_platform(
                        keyword=keyword,
                        platform=platform,
                        target_brand=brand_name,
                        brand_aliases=brand_aliases,
                    )
                except Exception as e:
                    # Best effort: a single failing query must not sink the channel.
                    failed_queries += 1
                    logger.warning(
                        "AI platform query failed: platform=%s, keyword=%s, error=%s",
                        platform,
                        keyword,
                        e,
                    )
                    continue
                results.append(result)

        signals = self._summarize_ai_results(results)
        signals["metadata"] = {
            "platforms_queried": _DEFAULT_PLATFORMS,
            "keywords_used": keywords,
            "total_responses": signals["total_responses"],
            "cited_count": signals["cited_count"],
            "failed_queries": failed_queries,
        }
        return signals

    def _summarize_ai_results(self, results: list[dict]) -> dict:
        """Turn raw AI-platform results into diagnosis signals (no I/O).

        Reads the ``cited``, ``match_type``, ``competitor_brands`` and
        ``source_urls`` keys of each result dict. The signals are:

        * ``aor`` -- the share of responses that cite the brand;
        * ``accuracy`` -- the share of citing responses with an ``"exact"`` match;
        * ``competitor_gap`` -- how far the most-mentioned competitor is ahead
          of the brand, divided by the response count and clamped to [0, 1].
          With no responses it defaults to 0.5.
        """
        total = len(results)
        cited = [r for r in results if r.get("cited")]
        cited_count = len(cited)
        # Only citing responses can be accurate citations; counting uncited
        # "exact" matches could push accuracy above 1.0.
        accurate_count = sum(1 for r in cited if r.get("match_type") == "exact")

        aor = cited_count / total if total else 0.0
        accuracy = accurate_count / cited_count if cited_count else 0.0
        sov = aor * 0.6  # heuristic share-of-voice proxy

        mentions = self._count_competitor_mentions(r.get("competitor_brands") for r in results)
        max_competitor = max(mentions.values(), default=0)
        if total:
            competitor_gap = min(1.0, max(0.0, (max_competitor - cited_count) / total))
        else:
            competitor_gap = 0.5  # no responses at all: neutral default

        return {
            "total_responses": total,
            "cited_count": cited_count,
            "accurate_count": accurate_count,
            "aor": aor,
            "accuracy": accuracy,
            "sov": sov,
            "competitor_gap": competitor_gap,
            # Heuristic proxies: any citation counts as an author bio, and
            # three citations count as complete credentials.
            "has_author_bio": cited_count > 0,
            "author_credentials_complete": min(1.0, cited_count / 3),
            "has_data_sources": any(r.get("source_urls") for r in results),
        }

    @staticmethod
    def _count_competitor_mentions(brand_lists) -> dict[str, int]:
        """Count how often each competitor name appears across *brand_lists*.

        *brand_lists* is an iterable of per-response competitor collections.
        Entries that are not a list/tuple (e.g. ``None`` or a bare string) are
        skipped, as are non-string names.
        """
        counts: dict[str, int] = {}
        for brands in brand_lists:
            if not isinstance(brands, (list, tuple)):
                continue
            for name in brands:
                if isinstance(name, str):
                    counts[name] = counts.get(name, 0) + 1
        return counts

    async def _collect_citation_record_signals(
        self,
        brand_name: str,
        brand_aliases: list[str],
    ) -> dict:
        """Load recent stored citation records for *brand_name* and summarise them.

        Reads up to ``_CITATION_RECORD_LIMIT`` ``CitationRecord`` rows, joined
        to their ``Query`` with ``Query.target_brand == brand_name``, newest
        ``queried_at`` first.

        NOTE(review): *brand_aliases* is accepted but not used in the query, so
        records stored under an alias are not found -- confirm whether aliases
        should be included.
        """
        stmt = (
            select(CitationRecord)
            .join(Query, CitationRecord.query_id == Query.id)
            .where(Query.target_brand == brand_name)
            .order_by(CitationRecord.queried_at.desc())
            .limit(self._CITATION_RECORD_LIMIT)
        )
        result = await self._db.execute(stmt)
        return self._summarize_citation_records(result.scalars().all())

    def _summarize_citation_records(self, records) -> dict:
        """Turn citation records into diagnosis signals (no I/O).

        Each record is read through its ``cited``, ``match_type``,
        ``competitor_brands`` and ``sentiment`` attributes. Several outputs are
        heuristic proxies derived from citation counts (see inline comments).
        """
        if not records:
            return {
                "total_responses": 0,
                "cited_count": 0,
                "accurate_count": 0,
                "aor": 0.0,
                "accuracy": 0.0,
                "sov": 0.0,
                "competitor_gap": 0.0,
                "metadata": {"records_found": 0},
            }

        total = len(records)
        cited_count = sum(1 for r in records if r.cited)
        accurate_count = sum(1 for r in records if r.cited and r.match_type == "exact")
        aor = cited_count / total
        accuracy = accurate_count / cited_count if cited_count else 0.0
        sov = aor * 0.5  # heuristic share-of-voice proxy

        mentions = self._count_competitor_mentions(r.competitor_brands for r in records)
        max_competitor = max(mentions.values(), default=0)
        competitor_gap = min(1.0, max(0.0, (max_competitor - cited_count) / total))

        # Heuristic proxy: positive-sentiment records stand in for certifications.
        cert_count = sum(1 for r in records if r.sentiment == "positive")

        return {
            "total_responses": total,
            "cited_count": cited_count,
            "accurate_count": accurate_count,
            "aor": aor,
            "accuracy": accuracy,
            "sov": min(sov, 1.0),
            "competitor_gap": competitor_gap,
            "has_certifications": cert_count > 0,
            "certification_count": cert_count,
            # Heuristic proxy: three or more citations count as expert endorsement.
            "has_expert_endorsements": cited_count >= 3,
            "endorsement_count": min(cited_count, 5),
            "content_depth_score": min(1.0, total / 20),
            "topic_coverage_ratio": min(1.0, cited_count / 10),
            "entity_consistency_score": min(1.0, accuracy * 1.1) if accuracy > 0 else 0.1,
            "cluster_completeness": min(1.0, cited_count / 15),
            "total_content_count": total,
            "topic_cluster_count": min(cited_count, 10),
            "metadata": {"records_found": total},
        }

    async def _collect_website_signals(self, website: str | None) -> dict:
        """Fetch *website* and extract on-page signals from its HTML.

        Best effort: a missing URL or any fetch error returns a metadata-only
        dict marked ``skipped`` instead of raising. A URL without a scheme is
        fetched over ``https://``.
        """
        url = (website or "").strip()
        if not url:
            return {"metadata": {"skipped": True, "reason": "no_website"}}
        if not self._URL_SCHEME_RE.match(url):
            # httpx rejects scheme-less URLs such as "example.com".
            url = f"https://{url}"

        # NOTE(security): this URL is fetched server-side with redirects
        # followed. If it is user-supplied, it is an SSRF vector (internal
        # hosts, cloud metadata endpoints); consider validating the resolved
        # host before fetching.
        try:
            async with httpx.AsyncClient(timeout=15, follow_redirects=True) as client:
                resp = await client.get(
                    url,
                    headers={
                        "User-Agent": (
                            "Mozilla/5.0 (compatible; GEO-Diagnosis-Bot/1.0)"
                        ),
                        "Accept": "text/html",
                    },
                )
                resp.raise_for_status()
                html = resp.text
        except Exception as e:
            logger.warning("Website fetch failed: %s, error=%s", url, e)
            return {"metadata": {"skipped": True, "reason": str(e) or type(e).__name__}}

        signals = self._parse_html_signals(html)
        signals["metadata"] = {"url": url, "html_length": len(html)}
        return signals

    def _parse_html_signals(self, html: str) -> dict:
        """Derive boolean GEO signals from raw HTML using lightweight heuristics.

        Schema flags only look inside ``<script type="application/ld+json">``
        blocks. The other flags work on the page without script/style bodies
        and comments, so code such as ``for (...)`` cannot trigger them.
        """
        # Local import: the parameter name shadows the ``html`` module.
        from html import unescape

        signals: dict = {}

        # --- schema.org markup, searched only inside JSON-LD blocks ---
        ld_json = "\n".join(self._LD_JSON_RE.findall(html))
        for key, type_names in self._SCHEMA_TYPE_TOKENS:
            signals[key] = any(f'"{name}"' in ld_json for name in type_names)

        content_html = self._NON_CONTENT_RE.sub(" ", html)

        # --- at least two h2/h3 headings phrased as questions ---
        headings = self._HEADING_RE.findall(content_html)
        qa_count = sum(1 for h in headings if self._QA_RE.search(self._TAG_RE.sub("", h)))
        signals["has_qa_headings"] = qa_count >= 2

        signals["has_structured_data"] = bool(self._LIST_OR_TABLE_RE.search(content_html))
        signals["has_internal_links"] = bool(self._INTERNAL_LINK_RE.search(content_html))
        # Dates can also live in JSON-LD values, so scan the raw HTML here.
        signals["has_freshness_info"] = bool(self._FRESHNESS_RE.search(html))

        # --- visible-text heuristics on the first 500 characters ---
        body_text = unescape(self._TAG_RE.sub(" ", content_html))
        body_text = self._WHITESPACE_RE.sub(" ", body_text).strip()
        first_500 = body_text[:500].lower()

        # More than 200 characters of visible text counts as a direct answer.
        signals["has_direct_answer"] = len(body_text) > 200
        signals["has_brand_definition"] = any(
            kw in first_500 for kw in self._BRAND_DEFINITION_KEYWORDS
        )
        signals["has_target_audience"] = any(
            p.search(first_500) for p in self._AUDIENCE_PATTERNS
        )
        signals["has_unique_value"] = any(v in first_500 for v in self._UNIQUE_VALUE_KEYWORDS)

        return signals

    @staticmethod
    def _merge_max(inp, data: dict, fields) -> None:
        """For each (attr, key, default) in *fields*, keep max(inp.attr, data[key])."""
        for attr, key, default in fields:
            setattr(inp, attr, max(getattr(inp, attr), data.get(key, default)))

    def _apply_ai_signals(self, inp: GEODiagnosisInput, data: dict) -> None:
        """Merge AI-platform signals into *inp*; values are only ever raised."""
        self._merge_max(inp, data, self._RESPONSE_MAX_FIELDS)
        if data.get("has_author_bio"):
            inp.has_author_bio = True
        inp.author_credentials_complete = max(
            inp.author_credentials_complete, data.get("author_credentials_complete", 0.0)
        )
        if data.get("has_data_sources"):
            inp.has_data_sources = True

    def _apply_citation_signals(self, inp: GEODiagnosisInput, data: dict) -> None:
        """Merge citation-record signals into *inp*; values are only ever raised."""
        self._merge_max(inp, data, self._RESPONSE_MAX_FIELDS)
        self._merge_max(inp, data, self._CITATION_MAX_FIELDS)
        if data.get("has_certifications"):
            inp.has_certifications = True
        if data.get("has_expert_endorsements"):
            inp.has_expert_endorsements = True

    def _apply_website_signals(self, inp: GEODiagnosisInput, data: dict) -> None:
        """Set each website flag on *inp* to True when the website data has it set.

        Flags are never cleared, so True values from other sources survive.
        """
        for name in self._WEBSITE_FLAG_FIELDS:
            if data.get(name):
                setattr(inp, name, True)

    async def _safe_await(
        self, task: asyncio.Task, channel: str
    ) -> tuple[dict | None, str | None]:
        """Await *task*, turning an ``Exception`` into an error tuple.

        Returns ``(result, None)`` on success or ``(None, "<channel>: <detail>")``
        on failure. The detail falls back to the exception type name when the
        message is empty. Cancellation is a BaseException and still propagates.
        """
        try:
            return await task, None
        except Exception as e:
            logger.exception("Data collection channel '%s' failed: %s", channel, e)
            return None, f"{channel}: {str(e) or type(e).__name__}"
|