382 lines
14 KiB
Python
382 lines
14 KiB
Python
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
import pytest
|
||
import pytest_asyncio
|
||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||
from sqlalchemy.pool import StaticPool
|
||
|
||
from app.database import Base
|
||
from app.models.brand import Brand
|
||
from app.models.citation_record import CitationRecord
|
||
from app.models.query import Query
|
||
from app.models.user import User
|
||
from app.services.auth import hash_password
|
||
from app.services.diagnosis.data_collector import DataCollectorService, DataCollectionResult
|
||
from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput
|
||
|
||
|
||
@pytest_asyncio.fixture
async def async_engine():
    """Yield an async SQLite engine backed by an in-memory database.

    The ORM schema from ``Base.metadata`` is created before the engine is
    handed to the test. ``StaticPool`` is used so that all sessions share one
    underlying connection; otherwise each new connection to ``:memory:``
    would see its own empty database.

    The engine is disposed on teardown, including when schema creation
    itself fails, so no connection is left open.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        yield engine
    finally:
        # Runs on normal teardown and when create_all raises.
        await engine.dispose()
|
||
|
||
|
||
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield an ``AsyncSession`` bound to the ``async_engine`` fixture.

    The session is opened through an ``async_sessionmaker`` and closed by the
    surrounding ``async with`` once the test has finished.
    """
    session_options = {
        "class_": AsyncSession,
        "expire_on_commit": False,
        "autoflush": False,
        "autocommit": False,
    }
    session_factory = async_sessionmaker(async_engine, **session_options)
    async with session_factory() as db_session:
        yield db_session
|
||
|
||
|
||
class TestWebsiteSignalParsing:
    """Unit tests for ``DataCollectorService._parse_html_signals``."""

    @staticmethod
    def _signals_for(markup):
        """Return the signals parsed from *markup*.

        The service is created with ``__new__`` so ``__init__`` is skipped;
        ``_db`` is set to ``None`` because no database is involved here.
        """
        collector = DataCollectorService.__new__(DataCollectorService)
        collector._db = None
        return collector._parse_html_signals(markup)

    def test_parse_html_with_schema_org(self):
        """A page with JSON-LD, question headings, a list, a link and a date."""
        page = """
        <html><head>
        <script type="application/ld+json">{"@type": "Organization", "name": "Test"}</script>
        <script type="application/ld+json">{"@type": "Product", "name": "Widget"}</script>
        </head><body>
        <h2>什么是TestBrand?</h2>
        <p>TestBrand是一家专注于技术创新的公司,为企业提供智能化解决方案。</p>
        <h3>如何使用TestBrand?</h3>
        <ul><li>步骤1</li><li>步骤2</li></ul>
        <a href="/about">关于我们</a>
        <span>更新于2026年5月1日</span>
        </body></html>
        """
        parsed = self._signals_for(page)

        flags_expected_true = (
            "has_organization",
            "has_product",
            "has_qa_headings",
            "has_structured_data",
            "has_internal_links",
            "has_freshness_info",
            "has_brand_definition",
            "has_target_audience",
        )
        for flag in flags_expected_true:
            assert parsed[flag] is True

    def test_parse_html_minimal(self):
        """A bare paragraph yields no schema or Q&A-heading signals."""
        parsed = self._signals_for("<html><body><p>Hello world</p></body></html>")

        for flag in ("has_organization", "has_product", "has_qa_headings"):
            assert parsed[flag] is False

    def test_parse_html_article_schema(self):
        """Article, FAQPage and BreadcrumbList JSON-LD blocks are each detected."""
        page = """
        <html><head>
        <script type="application/ld+json">{"@type": "Article"}</script>
        <script type="application/ld+json">{"@type": "FAQPage"}</script>
        <script type="application/ld+json">{"@type": "BreadcrumbList"}</script>
        </head><body></body></html>
        """
        parsed = self._signals_for(page)

        for flag in ("has_article", "has_faq", "has_breadcrumb"):
            assert parsed[flag] is True
|
||
|
||
|
||
class TestSignalApplication:
    """Unit tests for the ``_apply_*_signals`` methods of ``DataCollectorService``.

    Each test feeds a payload dict to one of the methods and checks the
    resulting attributes on a fresh ``GEODiagnosisInput``.
    """

    @staticmethod
    def _bare_service():
        """Return a service created via ``__new__`` with ``_db`` set to ``None``."""
        collector = DataCollectorService.__new__(DataCollectorService)
        collector._db = None
        return collector

    def test_apply_ai_signals(self):
        """AI-platform payload values show up on the diagnosis input."""
        collector = self._bare_service()
        diagnosis_input = GEODiagnosisInput()
        payload = {
            "aor": 0.4,
            "accuracy": 0.85,
            "sov": 0.25,
            "competitor_gap": 0.15,
            "total_responses": 10,
            "cited_count": 4,
            "accurate_count": 3,
            "has_author_bio": True,
            "author_credentials_complete": 0.7,
            "has_data_sources": True,
        }

        collector._apply_ai_signals(diagnosis_input, payload)

        expected_values = {
            "answer_ownership_rate": 0.4,
            "citation_accuracy": 0.85,
            "ai_sov": 0.25,
            "competitor_gap": 0.15,
            "total_ai_responses": 10,
            "brand_mention_count": 4,
            "accurate_citation_count": 3,
            "author_credentials_complete": 0.7,
        }
        for attribute, value in expected_values.items():
            assert getattr(diagnosis_input, attribute) == value
        # Boolean flags are checked by identity, not just truthiness.
        assert diagnosis_input.has_author_bio is True
        assert diagnosis_input.has_data_sources is True

    def test_apply_citation_signals(self):
        """Citation-record payload values show up on the diagnosis input."""
        collector = self._bare_service()
        diagnosis_input = GEODiagnosisInput()
        payload = {
            "aor": 0.5,
            "accuracy": 0.9,
            "sov": 0.3,
            "competitor_gap": 0.1,
            "total_responses": 20,
            "cited_count": 10,
            "accurate_count": 9,
            "has_certifications": True,
            "certification_count": 3,
            "has_expert_endorsements": True,
            "endorsement_count": 5,
            "content_depth_score": 0.8,
            "topic_coverage_ratio": 0.7,
            "entity_consistency_score": 0.85,
            "cluster_completeness": 0.6,
            "total_content_count": 20,
            "topic_cluster_count": 8,
        }

        collector._apply_citation_signals(diagnosis_input, payload)

        expected_values = {
            "answer_ownership_rate": 0.5,
            "citation_accuracy": 0.9,
            "certification_count": 3,
            "content_depth_score": 0.8,
        }
        for attribute, value in expected_values.items():
            assert getattr(diagnosis_input, attribute) == value
        assert diagnosis_input.has_certifications is True

    def test_apply_website_signals(self):
        """Website flags, both True and False, show up on the diagnosis input."""
        collector = self._bare_service()
        diagnosis_input = GEODiagnosisInput()
        flags_on = (
            "has_direct_answer",
            "has_qa_headings",
            "has_structured_data",
            "has_internal_links",
            "has_freshness_info",
            "has_brand_definition",
            "has_target_audience",
            "has_unique_value",
            "has_organization",
            "has_product",
            "has_article",
        )
        flags_off = ("has_faq", "has_howto", "has_breadcrumb")
        payload = {**dict.fromkeys(flags_on, True), **dict.fromkeys(flags_off, False)}

        collector._apply_website_signals(diagnosis_input, payload)

        for attribute in (
            "has_direct_answer",
            "has_qa_headings",
            "has_organization",
            "has_product",
            "has_article",
        ):
            assert getattr(diagnosis_input, attribute) is True
        assert diagnosis_input.has_faq is False

    def test_signals_merge_max_values(self):
        """Applying AI then citation signals leaves the larger value in place."""
        collector = self._bare_service()
        diagnosis_input = GEODiagnosisInput()

        collector._apply_ai_signals(diagnosis_input, {"aor": 0.3, "accuracy": 0.7})
        collector._apply_citation_signals(
            diagnosis_input, {"aor": 0.5, "accuracy": 0.9}
        )

        assert diagnosis_input.answer_ownership_rate == 0.5
        assert diagnosis_input.citation_accuracy == 0.9
|
||
|
||
|
||
class TestDataCollectorIntegration:
    """Tests for ``DataCollectorService.collect`` with its channels mocked.

    The three collectors (``_collect_ai_platform_signals``,
    ``_collect_citation_record_signals``, ``_collect_website_signals``) are
    replaced with ``AsyncMock`` objects in every test, so only the
    orchestration inside ``collect`` is exercised.
    """

    @staticmethod
    def _empty_ai_signals():
        """Return a fresh AI-platform payload with zero responses."""
        return {
            "aor": 0.0,
            "accuracy": 0.0,
            "sov": 0.0,
            "competitor_gap": 0.5,
            "total_responses": 0,
            "cited_count": 0,
            "accurate_count": 0,
            "metadata": {},
        }

    @staticmethod
    def _empty_citation_signals():
        """Return a fresh citation-record payload with zero records found."""
        return {
            "total_responses": 0,
            "cited_count": 0,
            "accurate_count": 0,
            "aor": 0.0,
            "accuracy": 0.0,
            "sov": 0.0,
            "competitor_gap": 0.0,
            "metadata": {"records_found": 0},
        }

    @staticmethod
    def _patched_channels(service, *, ai, citation, website):
        """Return a context manager replacing the three collectors on *service*.

        Args:
            service: The ``DataCollectorService`` instance under test.
            ai: Replacement for ``_collect_ai_platform_signals``.
            citation: Replacement for ``_collect_citation_record_signals``.
            website: Replacement for ``_collect_website_signals``.

        The replacements are expected to be ``AsyncMock`` objects because
        the tests ``await service.collect(...)``.
        """
        return patch.multiple(
            service,
            _collect_ai_platform_signals=ai,
            _collect_citation_record_signals=citation,
            _collect_website_signals=website,
        )

    @pytest.mark.asyncio
    async def test_collect_with_no_data_sources(self, async_session):
        """With empty channels and no industry, a result object is still built."""
        service = DataCollectorService(async_session)

        with self._patched_channels(
            service,
            ai=AsyncMock(return_value=self._empty_ai_signals()),
            citation=AsyncMock(return_value=self._empty_citation_signals()),
            website=AsyncMock(
                return_value={"metadata": {"skipped": True, "reason": "no_website"}}
            ),
        ):
            result = await service.collect(brand_name="UnknownBrand")

        assert isinstance(result, DataCollectionResult)
        assert isinstance(result.diagnosis_input, GEODiagnosisInput)
        assert result.diagnosis_input.has_industry_classification is False

    @pytest.mark.asyncio
    async def test_collect_with_industry(self, async_session):
        """Passing ``industry`` sets ``has_industry_classification``."""
        service = DataCollectorService(async_session)

        with self._patched_channels(
            service,
            ai=AsyncMock(return_value=self._empty_ai_signals()),
            citation=AsyncMock(return_value=self._empty_citation_signals()),
            website=AsyncMock(return_value={"metadata": {"skipped": True}}),
        ):
            result = await service.collect(
                brand_name="TestBrand", industry="technology"
            )

        assert result.diagnosis_input.has_industry_classification is True

    @pytest.mark.asyncio
    async def test_collect_produces_nonzero_with_website_signals(
        self, async_session
    ):
        """Collected signals fed into ``GEODiagnosisService`` give a positive score."""
        service = DataCollectorService(async_session)

        ai_signals = {
            "aor": 0.2,
            "accuracy": 0.6,
            "sov": 0.1,
            "competitor_gap": 0.3,
            "total_responses": 5,
            "cited_count": 1,
            "accurate_count": 0,
            "has_author_bio": True,
            "author_credentials_complete": 0.5,
            "has_data_sources": False,
            "metadata": {},
        }
        website_signals = {
            "has_direct_answer": True,
            "has_qa_headings": True,
            "has_structured_data": True,
            "has_internal_links": True,
            "has_freshness_info": True,
            "has_brand_definition": True,
            "has_target_audience": True,
            "has_unique_value": True,
            "has_organization": True,
            "has_product": True,
            "has_article": True,
            "has_faq": False,
            "has_howto": False,
            "has_breadcrumb": False,
            "metadata": {"url": "https://test.com"},
        }

        with self._patched_channels(
            service,
            ai=AsyncMock(return_value=ai_signals),
            citation=AsyncMock(return_value=self._empty_citation_signals()),
            website=AsyncMock(return_value=website_signals),
        ):
            result = await service.collect(
                brand_name="TestBrand",
                website="https://test.com",
                industry="technology",
            )

        # Imported locally, as in the original test; the module-level import
        # only brings in GEODiagnosisInput.
        from app.services.diagnosis.geo_diagnosis import GEODiagnosisService

        geo_service = GEODiagnosisService()
        diagnosis = geo_service.diagnose(result.diagnosis_input)

        assert diagnosis.overall_score > 0
        assert len(diagnosis.dimensions) == 6

    @pytest.mark.asyncio
    async def test_collect_handles_channel_failure(self, async_session):
        """An exception in the AI channel is reported in ``result.errors``."""
        service = DataCollectorService(async_session)

        with self._patched_channels(
            service,
            ai=AsyncMock(side_effect=Exception("AI platform unavailable")),
            citation=AsyncMock(return_value=self._empty_citation_signals()),
            website=AsyncMock(return_value={"metadata": {"skipped": True}}),
        ):
            result = await service.collect(brand_name="TestBrand")

        assert len(result.errors) >= 1
        assert any("ai_platform" in e for e in result.errors)
|