geo/backend/tests/test_services/test_data_collector.py

382 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.models.brand import Brand
from app.models.citation_record import CitationRecord
from app.models.query import Query
from app.models.user import User
from app.services.auth import hash_password
from app.services.diagnosis.data_collector import DataCollectorService, DataCollectionResult
from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput
@pytest_asyncio.fixture
async def async_engine():
    """Yield an async engine backed by an in-memory SQLite database.

    Every table registered on ``Base.metadata`` is created before the engine
    is handed to the test, and the engine is disposed after the test resumes
    the fixture.
    """
    # NOTE(review): StaticPool + check_same_thread=False presumably keeps all
    # sessions on the single in-memory connection -- confirm against the
    # SQLAlchemy docs.
    memory_engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )

    async with memory_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield memory_engine

    await memory_engine.dispose()
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield an ``AsyncSession`` opened on the ``async_engine`` fixture.

    The session factory disables expire-on-commit, autoflush and autocommit;
    the session is closed by its context manager when the test finishes.
    """
    session_factory = async_sessionmaker(
        async_engine,
        autocommit=False,
        autoflush=False,
        expire_on_commit=False,
        class_=AsyncSession,
    )
    async with session_factory() as db_session:
        yield db_session
class TestWebsiteSignalParsing:
    """Checks of ``DataCollectorService._parse_html_signals`` on inline HTML."""

    @staticmethod
    def _bare_service():
        """Return a service built via ``__new__`` with ``_db`` set to ``None``.

        ``__init__`` is bypassed, so no database session is needed.
        """
        collector = DataCollectorService.__new__(DataCollectorService)
        collector._db = None
        return collector

    def test_parse_html_with_schema_org(self):
        """A page with JSON-LD blocks, headings, a list, a link and a date.

        Every flag listed below is expected to be ``True``.
        """
        collector = self._bare_service()
        # The markup is kept flush-left so the literal carries no indentation.
        page = """
<html><head>
<script type="application/ld+json">{"@type": "Organization", "name": "Test"}</script>
<script type="application/ld+json">{"@type": "Product", "name": "Widget"}</script>
</head><body>
<h2>什么是TestBrand</h2>
<p>TestBrand是一家专注于技术创新的公司为企业提供智能化解决方案。</p>
<h3>如何使用TestBrand</h3>
<ul><li>步骤1</li><li>步骤2</li></ul>
<a href="/about">关于我们</a>
<span>更新于2026年5月1日</span>
</body></html>
"""
        signals = collector._parse_html_signals(page)

        expected_true = (
            "has_organization",
            "has_product",
            "has_qa_headings",
            "has_structured_data",
            "has_internal_links",
            "has_freshness_info",
            "has_brand_definition",
            "has_target_audience",
        )
        for flag in expected_true:
            assert signals[flag] is True, flag

    def test_parse_html_minimal(self):
        """A page with a single paragraph leaves the three checked flags ``False``."""
        collector = self._bare_service()
        page = "<html><body><p>Hello world</p></body></html>"

        signals = collector._parse_html_signals(page)

        expected_false = (
            "has_organization",
            "has_product",
            "has_qa_headings",
        )
        for flag in expected_false:
            assert signals[flag] is False, flag

    def test_parse_html_article_schema(self):
        """Article, FAQPage and BreadcrumbList JSON-LD blocks each set their flag."""
        collector = self._bare_service()
        # The markup is kept flush-left so the literal carries no indentation.
        page = """
<html><head>
<script type="application/ld+json">{"@type": "Article"}</script>
<script type="application/ld+json">{"@type": "FAQPage"}</script>
<script type="application/ld+json">{"@type": "BreadcrumbList"}</script>
</head><body></body></html>
"""
        signals = collector._parse_html_signals(page)

        expected_true = (
            "has_article",
            "has_faq",
            "has_breadcrumb",
        )
        for flag in expected_true:
            assert signals[flag] is True, flag
class TestSignalApplication:
    """Checks that the ``_apply_*_signals`` methods populate a ``GEODiagnosisInput``."""

    @staticmethod
    def _service_and_input():
        """Return a ``__new__``-built service and a default ``GEODiagnosisInput``.

        ``__init__`` is bypassed and ``_db`` is set to ``None``, so no database
        session is needed.
        """
        collector = DataCollectorService.__new__(DataCollectorService)
        collector._db = None
        return collector, GEODiagnosisInput()

    def test_apply_ai_signals(self):
        """AI-platform payload keys land on the matching input attributes."""
        collector, diagnosis_input = self._service_and_input()
        payload = {
            "aor": 0.4,
            "accuracy": 0.85,
            "sov": 0.25,
            "competitor_gap": 0.15,
            "total_responses": 10,
            "cited_count": 4,
            "accurate_count": 3,
            "has_author_bio": True,
            "author_credentials_complete": 0.7,
            "has_data_sources": True,
        }

        collector._apply_ai_signals(diagnosis_input, payload)

        # Rate-style metrics.
        assert diagnosis_input.answer_ownership_rate == 0.4
        assert diagnosis_input.citation_accuracy == 0.85
        assert diagnosis_input.ai_sov == 0.25
        assert diagnosis_input.competitor_gap == 0.15
        # Raw counters.
        assert diagnosis_input.total_ai_responses == 10
        assert diagnosis_input.brand_mention_count == 4
        assert diagnosis_input.accurate_citation_count == 3
        # Authorship and sourcing fields.
        assert diagnosis_input.has_author_bio is True
        assert diagnosis_input.author_credentials_complete == 0.7
        assert diagnosis_input.has_data_sources is True

    def test_apply_citation_signals(self):
        """Citation-record payload values are copied onto the input.

        Only a subset of the payload keys is asserted.
        """
        collector, diagnosis_input = self._service_and_input()
        payload = {
            "aor": 0.5,
            "accuracy": 0.9,
            "sov": 0.3,
            "competitor_gap": 0.1,
            "total_responses": 20,
            "cited_count": 10,
            "accurate_count": 9,
            "has_certifications": True,
            "certification_count": 3,
            "has_expert_endorsements": True,
            "endorsement_count": 5,
            "content_depth_score": 0.8,
            "topic_coverage_ratio": 0.7,
            "entity_consistency_score": 0.85,
            "cluster_completeness": 0.6,
            "total_content_count": 20,
            "topic_cluster_count": 8,
        }

        collector._apply_citation_signals(diagnosis_input, payload)

        assert diagnosis_input.answer_ownership_rate == 0.5
        assert diagnosis_input.citation_accuracy == 0.9
        assert diagnosis_input.has_certifications is True
        assert diagnosis_input.certification_count == 3
        assert diagnosis_input.content_depth_score == 0.8

    def test_apply_website_signals(self):
        """Website boolean flags are copied onto the input, ``False`` included."""
        collector, diagnosis_input = self._service_and_input()
        payload = {
            "has_direct_answer": True,
            "has_qa_headings": True,
            "has_structured_data": True,
            "has_internal_links": True,
            "has_freshness_info": True,
            "has_brand_definition": True,
            "has_target_audience": True,
            "has_unique_value": True,
            "has_organization": True,
            "has_product": True,
            "has_article": True,
            "has_faq": False,
            "has_howto": False,
            "has_breadcrumb": False,
        }

        collector._apply_website_signals(diagnosis_input, payload)

        assert diagnosis_input.has_direct_answer is True
        assert diagnosis_input.has_qa_headings is True
        assert diagnosis_input.has_organization is True
        assert diagnosis_input.has_product is True
        assert diagnosis_input.has_article is True
        assert diagnosis_input.has_faq is False

    def test_signals_merge_max_values(self):
        """Applying AI signals, then citation signals, leaves the citation values.

        NOTE(review): the name suggests a max-merge, but the larger values are
        also the last ones applied, so this test cannot tell max-merge from
        last-write-wins -- confirm against the service.
        """
        collector, diagnosis_input = self._service_and_input()

        collector._apply_ai_signals(diagnosis_input, {"aor": 0.3, "accuracy": 0.7})
        collector._apply_citation_signals(diagnosis_input, {"aor": 0.5, "accuracy": 0.9})

        assert diagnosis_input.answer_ownership_rate == 0.5
        assert diagnosis_input.citation_accuracy == 0.9
class TestDataCollectorIntegration:
    """Runs ``DataCollectorService.collect`` with its three channel collectors mocked."""

    @staticmethod
    def _channel_patchers(service):
        """Return patchers for the AI, citation and website collectors, in that order.

        Each patcher swaps the method on ``service`` for an ``AsyncMock`` while
        it is active.
        """
        ai_patch = patch.object(
            service, "_collect_ai_platform_signals", new_callable=AsyncMock
        )
        cite_patch = patch.object(
            service, "_collect_citation_record_signals", new_callable=AsyncMock
        )
        web_patch = patch.object(
            service, "_collect_website_signals", new_callable=AsyncMock
        )
        return ai_patch, cite_patch, web_patch

    @staticmethod
    def _zero_ai_payload():
        """Return a fresh AI-channel payload with zero counts and rates."""
        return {
            "aor": 0.0,
            "accuracy": 0.0,
            "sov": 0.0,
            "competitor_gap": 0.5,
            "total_responses": 0,
            "cited_count": 0,
            "accurate_count": 0,
            "metadata": {},
        }

    @staticmethod
    def _empty_citation_payload():
        """Return a fresh citation-channel payload reporting zero records."""
        return {
            "total_responses": 0,
            "cited_count": 0,
            "accurate_count": 0,
            "aor": 0.0,
            "accuracy": 0.0,
            "sov": 0.0,
            "competitor_gap": 0.0,
            "metadata": {"records_found": 0},
        }

    @pytest.mark.asyncio
    async def test_collect_with_no_data_sources(self, async_session):
        """With all channels empty and no industry, the classification flag is ``False``."""
        service = DataCollectorService(async_session)
        ai_patch, cite_patch, web_patch = self._channel_patchers(service)

        with ai_patch as ai, cite_patch as cite, web_patch as web:
            ai.return_value = self._zero_ai_payload()
            cite.return_value = self._empty_citation_payload()
            web.return_value = {"metadata": {"skipped": True, "reason": "no_website"}}
            result = await service.collect(brand_name="UnknownBrand")

        assert isinstance(result, DataCollectionResult)
        assert isinstance(result.diagnosis_input, GEODiagnosisInput)
        assert result.diagnosis_input.has_industry_classification is False

    @pytest.mark.asyncio
    async def test_collect_with_industry(self, async_session):
        """Passing ``industry`` sets ``has_industry_classification`` to ``True``."""
        service = DataCollectorService(async_session)
        ai_patch, cite_patch, web_patch = self._channel_patchers(service)

        with ai_patch as ai, cite_patch as cite, web_patch as web:
            ai.return_value = self._zero_ai_payload()
            cite.return_value = self._empty_citation_payload()
            web.return_value = {"metadata": {"skipped": True}}
            result = await service.collect(
                brand_name="TestBrand", industry="technology"
            )

        assert result.diagnosis_input.has_industry_classification is True

    @pytest.mark.asyncio
    async def test_collect_produces_nonzero_with_website_signals(
        self, async_session
    ):
        """Collected input with AI and website signals yields a positive diagnosis.

        The collected input is passed to ``GEODiagnosisService.diagnose``; the
        overall score must be above zero and six dimensions must be reported.
        """
        service = DataCollectorService(async_session)
        ai_patch, cite_patch, web_patch = self._channel_patchers(service)

        with ai_patch as ai, cite_patch as cite, web_patch as web:
            ai.return_value = {
                "aor": 0.2,
                "accuracy": 0.6,
                "sov": 0.1,
                "competitor_gap": 0.3,
                "total_responses": 5,
                "cited_count": 1,
                "accurate_count": 0,
                "has_author_bio": True,
                "author_credentials_complete": 0.5,
                "has_data_sources": False,
                "metadata": {},
            }
            cite.return_value = self._empty_citation_payload()
            web.return_value = {
                "has_direct_answer": True,
                "has_qa_headings": True,
                "has_structured_data": True,
                "has_internal_links": True,
                "has_freshness_info": True,
                "has_brand_definition": True,
                "has_target_audience": True,
                "has_unique_value": True,
                "has_organization": True,
                "has_product": True,
                "has_article": True,
                "has_faq": False,
                "has_howto": False,
                "has_breadcrumb": False,
                "metadata": {"url": "https://test.com"},
            }
            result = await service.collect(
                brand_name="TestBrand",
                website="https://test.com",
                industry="technology",
            )

        # Imported locally, as in the original test body.
        from app.services.diagnosis.geo_diagnosis import GEODiagnosisService

        diagnosis = GEODiagnosisService().diagnose(result.diagnosis_input)
        assert diagnosis.overall_score > 0
        assert len(diagnosis.dimensions) == 6

    @pytest.mark.asyncio
    async def test_collect_handles_channel_failure(self, async_session):
        """An exception from the AI channel is recorded in ``result.errors``.

        At least one error entry must contain ``"ai_platform"``.
        """
        service = DataCollectorService(async_session)
        ai_patch, cite_patch, web_patch = self._channel_patchers(service)

        with ai_patch as ai, cite_patch as cite, web_patch as web:
            ai.side_effect = Exception("AI platform unavailable")
            cite.return_value = self._empty_citation_payload()
            web.return_value = {"metadata": {"skipped": True}}
            result = await service.collect(brand_name="TestBrand")

        assert len(result.errors) >= 1
        assert any("ai_platform" in e for e in result.errors)