"""Contract tests for the ``/api/v1/attribution`` endpoints.

Each test runs against a fresh in-memory SQLite database and an ASGI client
whose ``get_db`` / ``get_current_user`` dependencies are overridden to use the
test session and a pre-created "pro" user.
"""

import uuid
from datetime import UTC, datetime, timedelta

import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool

from app.api.deps import get_current_user, get_db
from app.database import Base
from app.main import app
from app.models.attribution_record import AttributionRecord
from app.models.brand import Brand
from app.models.diagnosis_record import DiagnosisRecord
from app.models.user import User
from app.services.auth import hash_password
from tests.fixtures.auth import _to_uuid


def _make_user(
    user_id: str | None = None,
    email: str = "test@example.com",
    plan: str = "free",
) -> User:
    """Build an unsaved, active, email-verified ``User``.

    Args:
        user_id: Explicit id to use; a random UUID string is generated when
            omitted (or falsy).
        email: Login email for the user.
        plan: Plan name stored on ``user.plan``.

    Returns:
        A transient ``User``; the caller must add and commit it.
    """
    uid = user_id or str(uuid.uuid4())
    user = User(
        id=uid,
        email=email,
        password=hash_password("Test@123456"),
        firstName="Test",
        lastName="User",
        isActive=True,
        emailVerified=True,
    )
    user.plan = plan
    # "free" gets a quota of 5 queries; any other plan gets 50.
    user.max_queries = 50 if plan != "free" else 5
    return user


# NOTE(review): this local definition shadows the ``_to_uuid`` imported from
# ``tests.fixtures.auth`` above, so the imported helper is never used in this
# module. It is kept because it is what currently runs (it accepts both
# ``str`` and ``uuid.UUID``); consider consolidating on the shared helper.
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:  # noqa: F811
    """Return *value* as a ``uuid.UUID``, passing UUID instances through."""
    if isinstance(value, uuid.UUID):
        return value
    return uuid.UUID(str(value))


@pytest_asyncio.fixture
async def async_engine():
    """Yield an async engine on a fresh in-memory SQLite DB with all tables.

    ``StaticPool`` keeps a single connection so every session sees the same
    in-memory database.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield engine
    await engine.dispose()


@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield an ``AsyncSession`` bound to ``async_engine``.

    ``expire_on_commit=False`` keeps committed objects readable after commit.
    """
    maker = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
        autocommit=False,
    )
    async with maker() as session:
        yield session


@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist and return a user on the "pro" plan."""
    user = _make_user(plan="pro")
    async_session.add(user)
    await async_session.commit()
    await async_session.refresh(user)
    return user


@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Persist and return an active brand named "TestBrand" owned by ``test_user``."""
    brand = Brand(
        id=uuid.uuid4(),
        user_id=_to_uuid(test_user.id),
        name="TestBrand",
        aliases=["TB"],
        website="https://testbrand.com",
        industry="technology",
        platforms=["wenxin"],
        frequency="weekly",
        status="active",
    )
    async_session.add(brand)
    await async_session.commit()
    await async_session.refresh(brand)
    return brand


@pytest_asyncio.fixture
async def test_diagnosis(async_session, test_brand, test_user):
    """Persist a completed "geo" diagnosis for ``test_brand`` scoring 45.0.

    The start-tracking tests expect this score to become the baseline.
    """
    record = DiagnosisRecord(
        brand_id=test_brand.id,
        user_id=_to_uuid(test_user.id),
        diagnosis_type="geo",
        status="completed",
        overall_score=45.0,
        result_json={
            "overall_score": 45.0,
            "health_level": "pass",
            "dimensions": [
                {"name": "内容可提取性", "score": 50},
                {"name": "E-E-A-T信号", "score": 40},
                {"name": "引用就绪度", "score": 45},
            ],
        },
        completed_at=datetime.now(UTC),
    )
    async_session.add(record)
    await async_session.commit()
    await async_session.refresh(record)
    return record


@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an ``AsyncClient`` bound to the app with DB/auth overridden.

    Requests made through the client receive ``async_session`` from
    ``get_db`` and ``test_user`` from ``get_current_user``. All dependency
    overrides are cleared on teardown, even if the client context raises.
    """

    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        app.dependency_overrides.clear()


def _make_attribution_record(
    user_id: str | uuid.UUID,
    brand_id: uuid.UUID,
    baseline_score: float = 45.0,
    current_score: float | None = None,
    score_delta: float | None = None,
    status: str = "tracking",
    published_at: datetime | None = None,
    window_end_at: datetime | None = None,
) -> AttributionRecord:
    """Build an unsaved ``AttributionRecord`` for tests.

    Args:
        user_id: Owning user's id. It is normalized with ``_to_uuid`` (as the
            ``Brand`` / ``DiagnosisRecord`` fixtures do), so ``test_user.id``
            may be passed directly.
        brand_id: Id of the tracked brand.
        baseline_score: Score at tracking start.
        current_score: Latest measured score, if any.
        score_delta: Change versus baseline, if any.
        status: Record status; defaults to "tracking".
        published_at: Defaults to the current UTC time.
        window_end_at: Defaults to 28 days after the current UTC time.

    Returns:
        A transient ``AttributionRecord``; the caller must add and commit it.
    """
    now = datetime.now(UTC)
    return AttributionRecord(
        user_id=_to_uuid(user_id),
        brand_id=brand_id,
        baseline_score=baseline_score,
        current_score=current_score,
        score_delta=score_delta,
        status=status,
        published_at=published_at if published_at is not None else now,
        window_end_at=(
            window_end_at
            if window_end_at is not None
            else now + timedelta(days=28)
        ),
    )


class TestStartTrackingContract:
    """POST /api/v1/attribution/start."""

    @pytest.mark.asyncio
    async def test_start_creates_attribution_record(
        self, async_client, test_brand, test_diagnosis
    ):
        response = await async_client.post(
            "/api/v1/attribution/start",
            json={"brand_id": str(test_brand.id)},
        )
        assert response.status_code == 200
        data = response.json()
        assert "id" in data
        assert data["brand_id"] == str(test_brand.id)
        # Baseline comes from the completed diagnosis in ``test_diagnosis``.
        assert data["baseline_score"] == 45.0
        assert data["status"] == "tracking"
        assert data["content_id"] is None

    @pytest.mark.asyncio
    async def test_start_with_content_id(
        self, async_client, test_brand, test_diagnosis
    ):
        content_id = str(uuid.uuid4())
        response = await async_client.post(
            "/api/v1/attribution/start",
            json={"brand_id": str(test_brand.id), "content_id": content_id},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["content_id"] == content_id

    @pytest.mark.asyncio
    async def test_start_with_invalid_brand_id(self, async_client):
        response = await async_client.post(
            "/api/v1/attribution/start",
            json={"brand_id": str(uuid.uuid4())},
        )
        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_start_without_diagnosis_uses_zero_baseline(
        self, async_client, test_brand
    ):
        response = await async_client.post(
            "/api/v1/attribution/start",
            json={"brand_id": str(test_brand.id)},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["baseline_score"] == 0.0


class TestGetBrandAttributionContract:
    """GET /api/v1/attribution/brand/{brand_id}."""

    @pytest.mark.asyncio
    async def test_get_brand_attribution_summary(
        self, async_client, test_brand, test_user, test_diagnosis, async_session
    ):
        record = _make_attribution_record(
            user_id=test_user.id,
            brand_id=test_brand.id,
            current_score=55.0,
            score_delta=10.0,
        )
        async_session.add(record)
        await async_session.commit()

        response = await async_client.get(
            f"/api/v1/attribution/brand/{test_brand.id}"
        )
        assert response.status_code == 200
        data = response.json()
        assert "records" in data
        assert "total_score_delta" in data
        assert "tracking_count" in data
        assert len(data["records"]) >= 1


class TestCheckAttributionContract:
    """POST /api/v1/attribution/{record_id}/check."""

    @pytest.mark.asyncio
    async def test_check_updates_attribution_record(
        self, async_client, test_brand, test_user, test_diagnosis, async_session
    ):
        record = _make_attribution_record(
            user_id=test_user.id,
            brand_id=test_brand.id,
        )
        async_session.add(record)
        await async_session.commit()
        await async_session.refresh(record)

        response = await async_client.post(
            f"/api/v1/attribution/{record.id}/check"
        )
        assert response.status_code == 200
        data = response.json()
        assert data["current_score"] is not None
        assert data["score_delta"] is not None

    @pytest.mark.asyncio
    async def test_check_nonexistent_record(self, async_client):
        response = await async_client.post(
            f"/api/v1/attribution/{uuid.uuid4()}/check"
        )
        assert response.status_code == 404


class TestGetROIReportContract:
    """GET /api/v1/attribution/roi/{brand_id}."""

    @pytest.mark.asyncio
    async def test_get_roi_report(
        self, async_client, test_brand, test_user, test_diagnosis, async_session
    ):
        record = _make_attribution_record(
            user_id=test_user.id,
            brand_id=test_brand.id,
            current_score=55.0,
            score_delta=10.0,
        )
        async_session.add(record)
        await async_session.commit()

        response = await async_client.get(
            f"/api/v1/attribution/roi/{test_brand.id}"
        )
        assert response.status_code == 200
        data = response.json()
        assert "roi_percentage" in data
        assert "value_generated" in data
        assert "subscription_cost" in data
        assert "break_even_delta" in data
        assert "brand_name" in data
        assert "current_plan" in data
        assert "tracking_records" in data
        assert data["brand_name"] == "TestBrand"

    @pytest.mark.asyncio
    async def test_get_roi_invalid_brand(self, async_client):
        response = await async_client.get(
            f"/api/v1/attribution/roi/{uuid.uuid4()}"
        )
        assert response.status_code == 404


class TestGetABComparisonContract:
    """GET /api/v1/attribution/ab-comparison/{brand_id}."""

    @pytest.mark.asyncio
    async def test_get_ab_comparison(
        self, async_client, test_brand, test_user, test_diagnosis, async_session
    ):
        # A second, later diagnosis so there is a "before" (test_diagnosis)
        # and an "after" to compare.
        later_diagnosis = DiagnosisRecord(
            brand_id=test_brand.id,
            user_id=_to_uuid(test_user.id),
            diagnosis_type="geo",
            status="completed",
            overall_score=65.0,
            result_json={
                "overall_score": 65.0,
                "health_level": "good",
                "dimensions": [
                    {"name": "内容可提取性", "score": 70},
                    {"name": "E-E-A-T信号", "score": 60},
                    {"name": "引用就绪度", "score": 65},
                ],
            },
            completed_at=datetime.now(UTC) + timedelta(hours=1),
        )
        async_session.add(later_diagnosis)
        await async_session.commit()

        response = await async_client.get(
            f"/api/v1/attribution/ab-comparison/{test_brand.id}"
        )
        assert response.status_code == 200
        data = response.json()
        assert "overall_before" in data
        assert "overall_after" in data
        assert "overall_delta" in data
        assert "dimensions" in data
        assert data["brand_name"] == "TestBrand"
        assert len(data["dimensions"]) > 0

    @pytest.mark.asyncio
    async def test_get_ab_comparison_no_data(self, async_client, test_brand):
        response = await async_client.get(
            f"/api/v1/attribution/ab-comparison/{test_brand.id}"
        )
        assert response.status_code == 404


class TestAttributionWindowExpiration:
    """Checking a record whose window has already ended."""

    @pytest.mark.asyncio
    async def test_expired_window_marks_completed(
        self, async_client, test_brand, test_user, test_diagnosis, async_session
    ):
        record = _make_attribution_record(
            user_id=test_user.id,
            brand_id=test_brand.id,
            published_at=datetime.now(UTC) - timedelta(days=30),
            window_end_at=datetime.now(UTC) - timedelta(days=2),
        )
        async_session.add(record)
        await async_session.commit()
        await async_session.refresh(record)

        response = await async_client.post(
            f"/api/v1/attribution/{record.id}/check"
        )
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "completed"