geo/backend/app/services/diagnosis/data_collector.py

425 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import asyncio
import logging
import re
from dataclasses import dataclass, field
from datetime import UTC, datetime
import httpx
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.citation_record import CitationRecord
from app.models.query import Query
from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput
logger = logging.getLogger(__name__)
# AI platform identifiers that are probed for every brand; each one is passed
# as the ``platform`` argument of ``execute_single_platform``.
_DEFAULT_PLATFORMS = ["deepseek", "kimi"]
# Prompt templates rendered with ``str.format(brand=..., industry=...)``.
# Roughly: "what is {brand}", "how good is {brand}",
# "recommend {industry} brands".
# NOTE(review): the collector only sends the first two rendered keywords, so
# the industry template is built but never queried — confirm this is intended.
_QUERY_KEYWORDS = [
"{brand}是什么",
"{brand}怎么样",
"推荐{industry}品牌",
]
@dataclass
class DataCollectionResult:
    """Outcome of one data-collection run.

    Bundles the populated diagnosis input with per-channel metadata and the
    list of channel-level error messages gathered while collecting.
    """

    # Aggregated signals handed to the diagnosis step.
    diagnosis_input: GEODiagnosisInput
    # Free-form run metadata; a fresh dict is created per instance.
    metadata: dict = field(default_factory=dict)
    # One message per failed channel; a fresh list is created per instance.
    errors: list[str] = field(default_factory=list)
class DataCollectorService:
def __init__(self, db: AsyncSession):
self._db = db
async def collect(
self,
brand_name: str,
brand_aliases: list[str] | None = None,
website: str | None = None,
industry: str | None = None,
) -> DataCollectionResult:
errors: list[str] = []
metadata: dict = {
"brand_name": brand_name,
"collected_at": datetime.now(UTC).isoformat(),
"channels": {},
}
ai_task = asyncio.create_task(
self._collect_ai_platform_signals(
brand_name, brand_aliases or [], industry
)
)
citation_task = asyncio.create_task(
self._collect_citation_record_signals(brand_name, brand_aliases or [])
)
website_task = asyncio.create_task(
self._collect_website_signals(website)
)
ai_result, ai_err = await self._safe_await(ai_task, "ai_platform")
citation_result, citation_err = await self._safe_await(
citation_task, "citation_record"
)
website_result, website_err = await self._safe_await(website_task, "website")
if ai_err:
errors.append(ai_err)
if citation_err:
errors.append(citation_err)
if website_err:
errors.append(website_err)
metadata["channels"]["ai_platform"] = ai_result.get("metadata", {}) if ai_result else {"error": ai_err}
metadata["channels"]["citation_record"] = citation_result.get("metadata", {}) if citation_result else {"error": citation_err}
metadata["channels"]["website"] = website_result.get("metadata", {}) if website_result else {"error": website_err}
diagnosis_input = GEODiagnosisInput()
if ai_result:
self._apply_ai_signals(diagnosis_input, ai_result)
if citation_result:
self._apply_citation_signals(diagnosis_input, citation_result)
if website_result:
self._apply_website_signals(diagnosis_input, website_result)
if industry:
diagnosis_input.has_industry_classification = True
return DataCollectionResult(
diagnosis_input=diagnosis_input,
metadata=metadata,
errors=errors,
)
async def _collect_ai_platform_signals(
self,
brand_name: str,
brand_aliases: list[str],
industry: str | None,
) -> dict:
from app.services.ai_engine.platform_bridge import execute_single_platform
keywords = []
for tpl in _QUERY_KEYWORDS:
kw = tpl.format(brand=brand_name, industry=industry or "科技")
keywords.append(kw)
all_results: list[dict] = []
for platform in _DEFAULT_PLATFORMS:
for keyword in keywords[:2]:
try:
result = await execute_single_platform(
keyword=keyword,
platform=platform,
target_brand=brand_name,
brand_aliases=brand_aliases,
)
all_results.append(result)
except Exception as e:
logger.warning(f"AI platform query failed: platform={platform}, keyword={keyword}, error={e}")
total = len(all_results)
cited_count = sum(1 for r in all_results if r.get("cited"))
accurate_count = sum(
1 for r in all_results if r.get("match_type") == "exact"
)
aor = cited_count / total if total > 0 else 0.0
accuracy = accurate_count / cited_count if cited_count > 0 else 0.0
sov = aor * 0.6
competitor_mentions: dict[str, int] = {}
for r in all_results:
for comp in r.get("competitor_brands", []):
competitor_mentions[comp] = competitor_mentions.get(comp, 0) + 1
max_comp_mentions = max(competitor_mentions.values()) if competitor_mentions else 0
competitor_gap = max(0.0, (max_comp_mentions - cited_count) / total) if total > 0 else 0.5
return {
"total_responses": total,
"cited_count": cited_count,
"accurate_count": accurate_count,
"aor": aor,
"accuracy": accuracy,
"sov": sov,
"competitor_gap": competitor_gap,
"has_author_bio": cited_count > 0,
"author_credentials_complete": min(1.0, cited_count / 3) if cited_count > 0 else 0.0,
"has_data_sources": any(r.get("source_urls") for r in all_results),
"metadata": {
"platforms_queried": _DEFAULT_PLATFORMS,
"keywords_used": keywords[:2],
"total_responses": total,
"cited_count": cited_count,
},
}
async def _collect_citation_record_signals(
self,
brand_name: str,
brand_aliases: list[str],
) -> dict:
stmt = (
select(CitationRecord)
.join(Query, CitationRecord.query_id == Query.id)
.where(Query.target_brand == brand_name)
.order_by(CitationRecord.queried_at.desc())
.limit(100)
)
result = await self._db.execute(stmt)
records = result.scalars().all()
if not records:
return {
"total_responses": 0,
"cited_count": 0,
"accurate_count": 0,
"aor": 0.0,
"accuracy": 0.0,
"sov": 0.0,
"competitor_gap": 0.0,
"metadata": {"records_found": 0},
}
total = len(records)
cited_count = sum(1 for r in records if r.cited)
accurate_count = sum(
1 for r in records if r.match_type == "exact" and r.cited
)
aor = cited_count / total if total > 0 else 0.0
accuracy = accurate_count / cited_count if cited_count > 0 else 0.0
sov = aor * 0.5
competitor_all: dict[str, int] = {}
for r in records:
if r.competitor_brands and isinstance(r.competitor_brands, list):
for comp in r.competitor_brands:
if isinstance(comp, str):
competitor_all[comp] = competitor_all.get(comp, 0) + 1
max_comp = max(competitor_all.values()) if competitor_all else 0
competitor_gap = max(0.0, (max_comp - cited_count) / total) if total > 0 else 0.0
has_certifications = any(
r.sentiment == "positive" for r in records if r.sentiment
)
cert_count = sum(1 for r in records if r.sentiment == "positive")
has_endorsements = cited_count >= 3
endorsement_count = min(cited_count, 5)
return {
"total_responses": total,
"cited_count": cited_count,
"accurate_count": accurate_count,
"aor": aor,
"accuracy": accuracy,
"sov": min(sov, 1.0),
"competitor_gap": min(competitor_gap, 1.0),
"has_certifications": has_certifications,
"certification_count": cert_count,
"has_expert_endorsements": has_endorsements,
"endorsement_count": endorsement_count,
"content_depth_score": min(1.0, total / 20),
"topic_coverage_ratio": min(1.0, cited_count / 10),
"entity_consistency_score": min(1.0, accuracy * 1.1) if accuracy > 0 else 0.1,
"cluster_completeness": min(1.0, cited_count / 15),
"total_content_count": total,
"topic_cluster_count": min(cited_count, 10),
"metadata": {"records_found": total},
}
async def _collect_website_signals(self, website: str | None) -> dict:
if not website:
return {"metadata": {"skipped": True, "reason": "no_website"}}
try:
async with httpx.AsyncClient(
timeout=15, follow_redirects=True
) as client:
resp = await client.get(
website,
headers={
"User-Agent": (
"Mozilla/5.0 (compatible; GEO-Diagnosis-Bot/1.0)"
),
"Accept": "text/html",
},
)
resp.raise_for_status()
html = resp.text
except Exception as e:
logger.warning(f"Website fetch failed: {website}, error={e}")
return {"metadata": {"skipped": True, "reason": str(e)}}
signals = self._parse_html_signals(html)
signals["metadata"] = {"url": website, "html_length": len(html)}
return signals
def _parse_html_signals(self, html: str) -> dict:
signals: dict = {}
has_ld_json = 'application/ld+json' in html
signals["has_organization"] = (
has_ld_json and ('"Organization"' in html or '"organization"' in html)
)
signals["has_product"] = (
has_ld_json and ('"Product"' in html or '"product"' in html)
)
signals["has_article"] = (
has_ld_json
and ('"Article"' in html or '"BlogPosting"' in html or '"article"' in html)
)
signals["has_faq"] = (
has_ld_json and ('"FAQPage"' in html or '"faq"' in html)
)
signals["has_howto"] = (
has_ld_json and ('"HowTo"' in html or '"howto"' in html)
)
signals["has_breadcrumb"] = (
has_ld_json and ('"BreadcrumbList"' in html or '"breadcrumb"' in html)
)
h2_h3 = re.findall(r"<h[23][^>]*>(.*?)</h[23]>", html, re.DOTALL | re.IGNORECASE)
qa_pattern = re.compile(r"[?]|如何|什么|为什么|怎么|哪|多少|是否|可以")
qa_headings = [h for h in h2_h3 if qa_pattern.search(re.sub(r"<[^>]+>", "", h))]
signals["has_qa_headings"] = len(qa_headings) >= 2
signals["has_structured_data"] = (
"<ul" in html or "<ol" in html or "<table" in html
)
signals["has_internal_links"] = 'href="/' in html or 'href="./' in html
date_pattern = re.compile(
r"(20\d{2}[-/年]\d{1,2}[-/月]\d{1,2}[日]?)"
r"|(更新于|发布于|最后更新|published|updated|modified)"
)
signals["has_freshness_info"] = bool(date_pattern.search(html))
body_text = re.sub(r"<[^>]+>", " ", html)
body_text = re.sub(r"\s+", " ", body_text).strip()
first_500 = body_text[:500].lower()
signals["has_direct_answer"] = len(body_text) > 200 and len(first_500) > 100
signals["has_brand_definition"] = any(
kw in first_500
for kw in ["", "提供", "专注于", "致力于", "is a", "provides", "offers"]
)
audience_patterns = [
"为.*提供", "服务.*用户", "帮助.*企业", "面向",
"for ", "serves ", "helps ",
]
signals["has_target_audience"] = any(
re.search(p, first_500) for p in audience_patterns
)
value_patterns = [
"优势", "特色", "不同", "独特", "领先", "首创", "唯一",
"advantage", "unique", "leading", "first",
]
signals["has_unique_value"] = any(v in first_500 for v in value_patterns)
return signals
def _apply_ai_signals(self, inp: GEODiagnosisInput, data: dict) -> None:
inp.answer_ownership_rate = max(inp.answer_ownership_rate, data.get("aor", 0.0))
inp.citation_accuracy = max(inp.citation_accuracy, data.get("accuracy", 0.0))
inp.ai_sov = max(inp.ai_sov, data.get("sov", 0.0))
inp.competitor_gap = max(inp.competitor_gap, data.get("competitor_gap", 0.0))
inp.total_ai_responses = max(inp.total_ai_responses, data.get("total_responses", 0))
inp.brand_mention_count = max(inp.brand_mention_count, data.get("cited_count", 0))
inp.accurate_citation_count = max(
inp.accurate_citation_count, data.get("accurate_count", 0)
)
if data.get("has_author_bio"):
inp.has_author_bio = True
if data.get("author_credentials_complete", 0) > inp.author_credentials_complete:
inp.author_credentials_complete = data["author_credentials_complete"]
if data.get("has_data_sources"):
inp.has_data_sources = True
def _apply_citation_signals(self, inp: GEODiagnosisInput, data: dict) -> None:
inp.answer_ownership_rate = max(inp.answer_ownership_rate, data.get("aor", 0.0))
inp.citation_accuracy = max(inp.citation_accuracy, data.get("accuracy", 0.0))
inp.ai_sov = max(inp.ai_sov, data.get("sov", 0.0))
inp.competitor_gap = max(inp.competitor_gap, data.get("competitor_gap", 0.0))
inp.total_ai_responses = max(inp.total_ai_responses, data.get("total_responses", 0))
inp.brand_mention_count = max(inp.brand_mention_count, data.get("cited_count", 0))
inp.accurate_citation_count = max(
inp.accurate_citation_count, data.get("accurate_count", 0)
)
if data.get("has_certifications"):
inp.has_certifications = True
inp.certification_count = max(
inp.certification_count, data.get("certification_count", 0)
)
if data.get("has_expert_endorsements"):
inp.has_expert_endorsements = True
inp.endorsement_count = max(
inp.endorsement_count, data.get("endorsement_count", 0)
)
inp.content_depth_score = max(
inp.content_depth_score, data.get("content_depth_score", 0.0)
)
inp.topic_coverage_ratio = max(
inp.topic_coverage_ratio, data.get("topic_coverage_ratio", 0.0)
)
inp.entity_consistency_score = max(
inp.entity_consistency_score, data.get("entity_consistency_score", 0.0)
)
inp.cluster_completeness = max(
inp.cluster_completeness, data.get("cluster_completeness", 0.0)
)
inp.total_content_count = max(
inp.total_content_count, data.get("total_content_count", 0)
)
inp.topic_cluster_count = max(
inp.topic_cluster_count, data.get("topic_cluster_count", 0)
)
def _apply_website_signals(self, inp: GEODiagnosisInput, data: dict) -> None:
bool_fields = [
"has_direct_answer",
"has_qa_headings",
"has_structured_data",
"has_internal_links",
"has_freshness_info",
"has_brand_definition",
"has_target_audience",
"has_unique_value",
]
schema_fields = [
("has_organization", "has_organization"),
("has_product", "has_product"),
("has_article", "has_article"),
("has_faq", "has_faq"),
("has_howto", "has_howto"),
("has_breadcrumb", "has_breadcrumb"),
]
for f in bool_fields:
if data.get(f):
setattr(inp, f, True)
for data_key, inp_key in schema_fields:
if data.get(data_key):
setattr(inp, inp_key, True)
async def _safe_await(self, task: asyncio.Task, channel: str) -> tuple:
try:
result = await task
return result, None
except Exception as e:
logger.error(f"Data collection channel '{channel}' failed: {e}", exc_info=True)
return None, f"{channel}: {str(e)}"