"""Contract tests for the GEO diagnosis HTTP endpoints.

Each test drives the application through an in-process ASGI client with the
``get_db`` and ``get_current_user`` dependencies overridden, so requests run
against an in-memory SQLite database as a known user.
"""

import uuid
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, patch

import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import (
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.pool import StaticPool

from app.api.deps import get_current_user, get_db
from app.database import Base
from app.main import app
from app.models.brand import Brand
from app.models.diagnosis_record import DiagnosisRecord
from app.models.user import User
from app.services.auth import hash_password
from app.services.diagnosis.geo_diagnosis import (
    GEODiagnosisInput,
    GEODiagnosisService,
)

# Patch target for the diagnosis runner that the trigger tests replace with
# an AsyncMock.
_RUN_GEO_DIAGNOSIS = "app.api.diagnosis._run_geo_diagnosis"


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_user(
    user_id: str | None = None,
    email: str = "test@example.com",
    plan: str = "free",
) -> User:
    """Build an unsaved, active, email-verified ``User``.

    Args:
        user_id: Explicit id; a random UUID4 string is generated when omitted.
        email: Email address for the user.
        plan: Plan name. ``"free"`` gets ``max_queries = 5``; any other value
            gets ``max_queries = 50``.

    Returns:
        The constructed ``User`` (not yet added to any session).
    """
    uid = user_id or str(uuid.uuid4())
    user = User(
        id=uid,
        email=email,
        password=hash_password("Test@123456"),
        firstName="Test",
        lastName="User",
        isActive=True,
        emailVerified=True,
    )
    # plan / max_queries are assigned after construction rather than passed
    # to the constructor.
    user.plan = plan
    user.max_queries = 50 if plan != "free" else 5
    return user


def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
    """Return *value* as a ``uuid.UUID``, parsing it when it is not one."""
    if isinstance(value, uuid.UUID):
        return value
    return uuid.UUID(str(value))


def _rich_geo_input() -> GEODiagnosisInput:
    """Return the diagnosis input shared by the free-vs-paid report tests."""
    return GEODiagnosisInput(
        has_direct_answer=True,
        has_brand_definition=True,
        has_author_bio=True,
        author_credentials_complete=0.8,
        has_organization=True,
        content_depth_score=0.7,
        answer_ownership_rate=0.3,
    )


async def _persist_user(session: AsyncSession, user: User) -> User:
    """Add *user* to *session*, commit, refresh and return it."""
    session.add(user)
    await session.commit()
    await session.refresh(user)
    return user


async def _persist_brand(
    session: AsyncSession,
    owner: User,
    *,
    name: str,
    aliases: list[str],
    website: str,
) -> Brand:
    """Create, commit and refresh an active technology ``Brand`` for *owner*."""
    brand = Brand(
        id=uuid.uuid4(),
        user_id=_to_uuid(owner.id),
        name=name,
        aliases=aliases,
        website=website,
        industry="technology",
        platforms=["wenxin", "kimi"],
        frequency="weekly",
        status="active",
    )
    session.add(brand)
    await session.commit()
    await session.refresh(brand)
    return brand


async def _store_completed_geo_record(
    session: AsyncSession,
    brand: Brand,
    user_id: str | uuid.UUID,
    input_data: GEODiagnosisInput,
) -> DiagnosisRecord:
    """Run the real diagnosis service and persist its result as "completed".

    Args:
        session: Session used to store the record.
        brand: Brand the record belongs to.
        user_id: Owner id stored on the record (coerced with ``_to_uuid``).
        input_data: Input passed to ``GEODiagnosisService.diagnose``.

    Returns:
        The committed and refreshed ``DiagnosisRecord``.
    """
    result = GEODiagnosisService().diagnose(input_data)
    record = DiagnosisRecord(
        brand_id=brand.id,
        user_id=_to_uuid(user_id),
        diagnosis_type="geo",
        status="completed",
        overall_score=result.overall_score,
        result_json=result.to_dict(),
    )
    session.add(record)
    await session.commit()
    await session.refresh(record)
    return record


async def _fetch_result(client: AsyncClient, brand_id, task_id):
    """GET the GEO diagnosis result for *brand_id* / *task_id*."""
    return await client.get(
        f"/api/v1/diagnosis/geo/{brand_id}/result",
        params={"task_id": str(task_id)},
    )


@asynccontextmanager
async def _client_as(session: AsyncSession, user: User) -> AsyncIterator[AsyncClient]:
    """Yield an ASGI test client bound to *session* and authenticated as *user*.

    The ``get_db`` and ``get_current_user`` dependencies are overridden for
    the lifetime of the context. Overrides are cleared in a ``finally`` so a
    failure while setting up the client cannot leak them into the shared
    ``app`` object.
    """

    async def override_get_db():
        yield session

    async def override_get_current_user():
        return user

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        app.dependency_overrides.clear()


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest_asyncio.fixture
async def async_engine():
    """Yield an in-memory SQLite async engine with all tables created."""
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield engine
    await engine.dispose()


@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield an ``AsyncSession`` bound to the test engine."""
    maker = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
        autocommit=False,
    )
    async with maker() as session:
        yield session


@pytest_asyncio.fixture
async def test_user(async_session):
    """Persisted user on the "free" plan."""
    return await _persist_user(async_session, _make_user(plan="free"))


@pytest_asyncio.fixture
async def paid_user(async_session):
    """Persisted user on the "pro" plan."""
    return await _persist_user(
        async_session, _make_user(email="paid@example.com", plan="pro")
    )


@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Persisted brand owned by the free user."""
    return await _persist_brand(
        async_session,
        test_user,
        name="TestBrand",
        aliases=["TB", "Test Brand"],
        website="https://testbrand.com",
    )


@pytest_asyncio.fixture
async def paid_brand(async_session, paid_user):
    """Persisted brand owned by the paid user."""
    return await _persist_brand(
        async_session,
        paid_user,
        name="PaidBrand",
        aliases=["PB"],
        website="https://paidbrand.com",
    )


@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """HTTP client authenticated as the free user."""
    async with _client_as(async_session, test_user) as client:
        yield client


@pytest_asyncio.fixture
async def paid_client(async_session, paid_user):
    """HTTP client authenticated as the paid user."""
    async with _client_as(async_session, paid_user) as client:
        yield client


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


class TestGEODiagnosisTriggerContract:
    """POST /api/v1/diagnosis/geo/{brand_id}."""

    @pytest.mark.asyncio
    async def test_trigger_returns_202(self, async_client, test_brand):
        """Triggering returns 202 with task_id, brand_id and status."""
        with patch(_RUN_GEO_DIAGNOSIS, new_callable=AsyncMock):
            response = await async_client.post(
                f"/api/v1/diagnosis/geo/{test_brand.id}"
            )

        assert response.status_code == 202
        data = response.json()
        assert "task_id" in data
        assert "brand_id" in data
        assert "status" in data
        assert data["status"] in ("pending", "completed")
        assert data["brand_id"] == str(test_brand.id)

    @pytest.mark.asyncio
    async def test_trigger_brand_not_found(self, async_client):
        """An unknown brand id yields 404."""
        response = await async_client.post(
            f"/api/v1/diagnosis/geo/{uuid.uuid4()}"
        )
        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_trigger_force_refresh(self, async_client, test_brand):
        """A ``force_refresh`` body is accepted and still returns 202."""
        with patch(_RUN_GEO_DIAGNOSIS, new_callable=AsyncMock):
            response = await async_client.post(
                f"/api/v1/diagnosis/geo/{test_brand.id}",
                json={"force_refresh": True},
            )
        assert response.status_code == 202


class TestGEODiagnosisResultContract:
    """GET /api/v1/diagnosis/geo/{brand_id}/result."""

    @pytest.mark.asyncio
    async def test_result_pending(self, async_client, test_brand, async_session):
        """A pending record is reported as pending with a null result."""
        record = DiagnosisRecord(
            brand_id=test_brand.id,
            user_id=_to_uuid(test_brand.user_id),
            diagnosis_type="geo",
            status="pending",
        )
        async_session.add(record)
        await async_session.commit()
        await async_session.refresh(record)

        response = await _fetch_result(async_client, test_brand.id, record.id)

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "pending"
        assert data["result"] is None

    @pytest.mark.asyncio
    async def test_result_completed_nonzero_score(
        self, async_client, test_brand, async_session
    ):
        """A completed record exposes a result with a positive overall score."""
        input_data = GEODiagnosisInput(
            has_direct_answer=True,
            has_brand_definition=True,
            has_author_bio=True,
            author_credentials_complete=0.8,
            answer_ownership_rate=0.3,
        )
        record = await _store_completed_geo_record(
            async_session, test_brand, test_brand.user_id, input_data
        )

        response = await _fetch_result(async_client, test_brand.id, record.id)

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "completed"
        assert data["result"] is not None
        assert data["result"]["overall_score"] > 0

    @pytest.mark.asyncio
    async def test_result_not_found(self, async_client, test_brand):
        """An unknown task id yields 404."""
        response = await _fetch_result(async_client, test_brand.id, uuid.uuid4())
        assert response.status_code == 404


class TestGEODiagnosisFreeVsPaidContract:
    """Report content differs between the free and paid plans."""

    @pytest.mark.asyncio
    async def test_free_user_gets_3_dimensions(
        self, async_client, test_brand, async_session
    ):
        """The free user gets a partial report with exactly three dimensions."""
        record = await _store_completed_geo_record(
            async_session, test_brand, test_brand.user_id, _rich_geo_input()
        )

        response = await _fetch_result(async_client, test_brand.id, record.id)

        # Fail on the status code first so an error response is not reported
        # as a confusing KeyError/TypeError from the lookups below.
        assert response.status_code == 200
        data = response.json()
        assert data["result"]["is_full_report"] is False
        dim_names = {d["name"] for d in data["result"]["dimensions"]}
        assert len(dim_names) == 3
        assert "内容可提取性" in dim_names
        assert "E-E-A-T信号" in dim_names
        assert "引用就绪度" in dim_names

    @pytest.mark.asyncio
    async def test_paid_user_gets_6_dimensions(
        self, paid_client, paid_brand, async_session, paid_user
    ):
        """The paid user gets the full report with six dimensions."""
        record = await _store_completed_geo_record(
            async_session, paid_brand, paid_user.id, _rich_geo_input()
        )

        response = await _fetch_result(paid_client, paid_brand.id, record.id)

        # Same reasoning as the free-plan test: check the status before
        # indexing into the payload.
        assert response.status_code == 200
        data = response.json()
        assert data["result"]["is_full_report"] is True
        assert len(data["result"]["dimensions"]) == 6


class TestGEODiagnosisHistoryContract:
    """GET /api/v1/diagnosis/geo/{brand_id}/history."""

    @pytest.mark.asyncio
    async def test_history_returns_list(
        self, async_client, test_brand, async_session
    ):
        """History responds 200 with ``brand_id`` and a list under ``history``."""
        record = DiagnosisRecord(
            brand_id=test_brand.id,
            user_id=_to_uuid(test_brand.user_id),
            diagnosis_type="geo",
            status="completed",
            overall_score=45.0,
            result_json={
                "overall_score": 45.0,
                "health_level": "pass",
            },
        )
        async_session.add(record)
        await async_session.commit()

        response = await async_client.get(
            f"/api/v1/diagnosis/geo/{test_brand.id}/history"
        )

        assert response.status_code == 200
        data = response.json()
        assert "brand_id" in data
        assert "history" in data
        assert isinstance(data["history"], list)


class TestGEODiagnosisDataCollectionContract:
    """Endpoints that are expected to involve data collection."""

    @pytest.mark.asyncio
    async def test_diagnosis_with_data_collection_produces_nonzero(
        self, async_client, test_brand
    ):
        """Triggering with ``force_refresh`` returns 202.

        NOTE(review): the runner is mocked here, so despite the test name no
        score is produced or checked; only the 202 status is asserted.
        """
        with patch(_RUN_GEO_DIAGNOSIS, new_callable=AsyncMock):
            response = await async_client.post(
                f"/api/v1/diagnosis/geo/{test_brand.id}",
                json={"force_refresh": True},
            )
        assert response.status_code == 202

    @pytest.mark.asyncio
    async def test_combined_diagnosis_uses_data_collector(
        self, async_client, test_brand
    ):
        """The combined endpoint returns geo, seo and combined score keys.

        NOTE(review): only the response shape is asserted; nothing here
        verifies that a data collector was actually used.
        """
        response = await async_client.get(
            f"/api/v1/diagnosis/combined/{test_brand.id}"
        )

        assert response.status_code == 200
        data = response.json()
        assert "geo_score" in data
        assert "seo_score" in data
        assert "combined_score" in data