"""Tests for ``DataCollectorService``.

Covers three areas:
- HTML signal parsing (``_parse_html_signals``);
- application of per-channel signal dicts onto a ``GEODiagnosisInput``;
- the ``collect()`` orchestration, with every collection channel patched out.
"""

from contextlib import contextmanager
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
import pytest_asyncio
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool

from app.database import Base

# Model imports are kept even though the tests below do not reference them
# directly: presumably importing them registers their tables on
# Base.metadata before create_all() runs -- verify before removing.
from app.models.brand import Brand
from app.models.citation_record import CitationRecord
from app.models.query import Query
from app.models.user import User
from app.services.auth import hash_password
from app.services.diagnosis.data_collector import DataCollectorService, DataCollectionResult
from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput, GEODiagnosisService


@pytest_asyncio.fixture
async def async_engine():
    """Yield an in-memory SQLite async engine with all tables created.

    ``StaticPool`` reuses a single connection, so every session opened on
    this engine sees the same in-memory database. The engine is disposed on
    teardown, including when table creation itself fails.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        yield engine
    finally:
        await engine.dispose()


@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield an ``AsyncSession`` bound to the in-memory test engine."""
    maker = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
        autocommit=False,
    )
    async with maker() as session:
        yield session


def _bare_service() -> DataCollectorService:
    """Build a ``DataCollectorService`` without running ``__init__``.

    Used by the pure-function tests, which never touch the database.
    """
    service = DataCollectorService.__new__(DataCollectorService)
    service._db = None
    return service


def _empty_ai_signals() -> dict:
    """Return a fresh AI-platform payload representing "no responses found"."""
    return {
        "aor": 0.0,
        "accuracy": 0.0,
        "sov": 0.0,
        "competitor_gap": 0.5,
        "total_responses": 0,
        "cited_count": 0,
        "accurate_count": 0,
        "metadata": {},
    }


def _empty_citation_signals() -> dict:
    """Return a fresh citation-record payload representing "no records found"."""
    return {
        "total_responses": 0,
        "cited_count": 0,
        "accurate_count": 0,
        "aor": 0.0,
        "accuracy": 0.0,
        "sov": 0.0,
        "competitor_gap": 0.0,
        "metadata": {"records_found": 0},
    }


def _sample_website_signals() -> dict:
    """Return a fresh website payload: most signals present, FAQ/HowTo/breadcrumb absent."""
    return {
        "has_direct_answer": True,
        "has_qa_headings": True,
        "has_structured_data": True,
        "has_internal_links": True,
        "has_freshness_info": True,
        "has_brand_definition": True,
        "has_target_audience": True,
        "has_unique_value": True,
        "has_organization": True,
        "has_product": True,
        "has_article": True,
        "has_faq": False,
        "has_howto": False,
        "has_breadcrumb": False,
    }


@contextmanager
def _patched_collectors(service):
    """Patch the three collection channels of *service* with ``AsyncMock``s.

    Yields ``(mock_ai, mock_cite, mock_web)`` for the AI-platform,
    citation-record and website channels respectively.
    """
    with patch.object(
        service, "_collect_ai_platform_signals", new_callable=AsyncMock
    ) as mock_ai, patch.object(
        service, "_collect_citation_record_signals", new_callable=AsyncMock
    ) as mock_cite, patch.object(
        service, "_collect_website_signals", new_callable=AsyncMock
    ) as mock_web:
        yield mock_ai, mock_cite, mock_web


class TestWebsiteSignalParsing:
    """Exercise ``_parse_html_signals`` on raw HTML snippets."""

    # NOTE(review): the original HTML fixtures in this class were lost (their
    # markup was stripped from the file). The JSON-LD below is a
    # reconstruction using standard schema.org types -- confirm it matches
    # what ``_parse_html_signals`` actually detects.

    def test_parse_html_with_schema_org(self):
        service = _bare_service()
        html = """
<html>
<head>
<script type="application/ld+json">
{"@context": "https://schema.org", "@type": "Organization", "name": "TestBrand"}
</script>
<script type="application/ld+json">
{"@context": "https://schema.org", "@type": "Product", "name": "TestBrand Platform"}
</script>
</head>
<body>
<p>TestBrand是一家专注于技术创新的公司,为企业提供智能化解决方案。</p>
</body>
</html>
"""
        signals = service._parse_html_signals(html)
        assert signals["has_organization"] is True
        assert signals["has_product"] is True

    def test_parse_html_without_schema(self):
        service = _bare_service()
        html = "<html><body><p>Hello world</p></body></html>"
        signals = service._parse_html_signals(html)
        assert signals["has_organization"] is False
        assert signals["has_product"] is False
        assert signals["has_qa_headings"] is False

    def test_parse_html_article_schema(self):
        service = _bare_service()
        html = """
<html>
<head>
<script type="application/ld+json">
{"@context": "https://schema.org", "@type": "Article", "headline": "TestBrand"}
</script>
<script type="application/ld+json">
{"@context": "https://schema.org", "@type": "FAQPage", "mainEntity": []}
</script>
<script type="application/ld+json">
{"@context": "https://schema.org", "@type": "BreadcrumbList", "itemListElement": []}
</script>
</head>
<body></body>
</html>
"""
        signals = service._parse_html_signals(html)
        assert signals["has_article"] is True
        assert signals["has_faq"] is True
        assert signals["has_breadcrumb"] is True


class TestSignalApplication:
    """Check that each channel's signal dict is copied onto ``GEODiagnosisInput``."""

    def test_apply_ai_signals(self):
        service = _bare_service()
        inp = GEODiagnosisInput()
        ai_data = {
            "aor": 0.4,
            "accuracy": 0.85,
            "sov": 0.25,
            "competitor_gap": 0.15,
            "total_responses": 10,
            "cited_count": 4,
            "accurate_count": 3,
            "has_author_bio": True,
            "author_credentials_complete": 0.7,
            "has_data_sources": True,
        }
        service._apply_ai_signals(inp, ai_data)
        assert inp.answer_ownership_rate == 0.4
        assert inp.citation_accuracy == 0.85
        assert inp.ai_sov == 0.25
        assert inp.competitor_gap == 0.15
        assert inp.total_ai_responses == 10
        assert inp.brand_mention_count == 4
        assert inp.accurate_citation_count == 3
        assert inp.has_author_bio is True
        assert inp.author_credentials_complete == 0.7
        assert inp.has_data_sources is True

    def test_apply_citation_signals(self):
        service = _bare_service()
        inp = GEODiagnosisInput()
        citation_data = {
            "aor": 0.5,
            "accuracy": 0.9,
            "sov": 0.3,
            "competitor_gap": 0.1,
            "total_responses": 20,
            "cited_count": 10,
            "accurate_count": 9,
            "has_certifications": True,
            "certification_count": 3,
            "has_expert_endorsements": True,
            "endorsement_count": 5,
            "content_depth_score": 0.8,
            "topic_coverage_ratio": 0.7,
            "entity_consistency_score": 0.85,
            "cluster_completeness": 0.6,
            "total_content_count": 20,
            "topic_cluster_count": 8,
        }
        service._apply_citation_signals(inp, citation_data)
        assert inp.answer_ownership_rate == 0.5
        assert inp.citation_accuracy == 0.9
        assert inp.has_certifications is True
        assert inp.certification_count == 3
        assert inp.content_depth_score == 0.8

    def test_apply_website_signals(self):
        service = _bare_service()
        inp = GEODiagnosisInput()
        service._apply_website_signals(inp, _sample_website_signals())
        assert inp.has_direct_answer is True
        assert inp.has_qa_headings is True
        assert inp.has_organization is True
        assert inp.has_product is True
        assert inp.has_article is True
        assert inp.has_faq is False

    def test_signals_merge_max_values(self):
        # Applying AI then citation signals keeps the larger value of each metric.
        service = _bare_service()
        inp = GEODiagnosisInput()
        service._apply_ai_signals(inp, {"aor": 0.3, "accuracy": 0.7})
        service._apply_citation_signals(inp, {"aor": 0.5, "accuracy": 0.9})
        assert inp.answer_ownership_rate == 0.5
        assert inp.citation_accuracy == 0.9


class TestDataCollectorIntegration:
    """Run ``collect()`` end to end with every collection channel mocked."""

    @pytest.mark.asyncio
    async def test_collect_with_no_data_sources(self, async_session):
        service = DataCollectorService(async_session)
        with _patched_collectors(service) as (mock_ai, mock_cite, mock_web):
            mock_ai.return_value = _empty_ai_signals()
            mock_cite.return_value = _empty_citation_signals()
            mock_web.return_value = {"metadata": {"skipped": True, "reason": "no_website"}}
            result = await service.collect(brand_name="UnknownBrand")

        assert isinstance(result, DataCollectionResult)
        assert isinstance(result.diagnosis_input, GEODiagnosisInput)
        assert result.diagnosis_input.has_industry_classification is False

    @pytest.mark.asyncio
    async def test_collect_with_industry(self, async_session):
        service = DataCollectorService(async_session)
        with _patched_collectors(service) as (mock_ai, mock_cite, mock_web):
            mock_ai.return_value = _empty_ai_signals()
            mock_cite.return_value = _empty_citation_signals()
            mock_web.return_value = {"metadata": {"skipped": True}}
            result = await service.collect(brand_name="TestBrand", industry="technology")

        assert result.diagnosis_input.has_industry_classification is True

    @pytest.mark.asyncio
    async def test_collect_produces_nonzero_with_website_signals(self, async_session):
        service = DataCollectorService(async_session)
        with _patched_collectors(service) as (mock_ai, mock_cite, mock_web):
            mock_ai.return_value = {
                "aor": 0.2,
                "accuracy": 0.6,
                "sov": 0.1,
                "competitor_gap": 0.3,
                "total_responses": 5,
                "cited_count": 1,
                "accurate_count": 0,
                "has_author_bio": True,
                "author_credentials_complete": 0.5,
                "has_data_sources": False,
                "metadata": {},
            }
            mock_cite.return_value = _empty_citation_signals()
            mock_web.return_value = {
                **_sample_website_signals(),
                "metadata": {"url": "https://test.com"},
            }
            result = await service.collect(
                brand_name="TestBrand",
                website="https://test.com",
                industry="technology",
            )

        diagnosis = GEODiagnosisService().diagnose(result.diagnosis_input)
        assert diagnosis.overall_score > 0
        assert len(diagnosis.dimensions) == 6

    @pytest.mark.asyncio
    async def test_collect_handles_channel_failure(self, async_session):
        # A raising channel must be reported in result.errors, not propagate.
        service = DataCollectorService(async_session)
        with _patched_collectors(service) as (mock_ai, mock_cite, mock_web):
            mock_ai.side_effect = Exception("AI platform unavailable")
            mock_cite.return_value = _empty_citation_signals()
            mock_web.return_value = {"metadata": {"skipped": True}}
            result = await service.collect(brand_name="TestBrand")

        assert len(result.errors) >= 1
        assert any("ai_platform" in e for e in result.errors)