# geo/backend/tests/test_api/test_onboarding_contract.py
#
# 454 lines
# 15 KiB
# Python

import uuid
from unittest.mock import AsyncMock, patch
import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from app.api.deps import get_current_user, get_db
from app.database import Base
from app.main import app
from app.models.brand import Brand
from app.models.user import User
from app.services.auth import hash_password
def _make_user(
    user_id: str | None = None,
    email: str = "test@example.com",
    plan: str = "free",
) -> User:
    """Build an unsaved, active, email-verified ``User`` for tests.

    Args:
        user_id: Explicit primary key. When omitted (or falsy) a random
            UUID string is generated instead.
        email: E-mail address stored on the user.
        plan: Plan name assigned to ``user.plan``.

    Returns:
        A ``User`` that has not been added to any session. Its password is
        the hash of ``"Test@123456"``; ``max_queries`` is 5 on the
        ``"free"`` plan and 50 on every other plan.
    """
    is_free_plan = plan == "free"
    new_user = User(
        id=user_id or str(uuid.uuid4()),
        email=email,
        password=hash_password("Test@123456"),
        firstName="Test",
        lastName="User",
        isActive=True,
        emailVerified=True,
    )
    new_user.plan = plan
    new_user.max_queries = 5 if is_free_plan else 50
    return new_user
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
@pytest_asyncio.fixture
async def async_engine():
    """Yield an async SQLAlchemy engine backed by in-memory SQLite.

    ``StaticPool`` reuses one connection, so every session created from this
    engine sees the same in-memory database. All tables in ``Base.metadata``
    are created before the engine is yielded, and the engine is disposed
    during fixture teardown.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    async with engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)
    yield engine
    await engine.dispose()
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield a single ``AsyncSession`` bound to the test engine.

    ``expire_on_commit=False`` keeps attributes of committed objects readable
    without a reload. ``autoflush=False`` stops queries from implicitly
    flushing pending changes.
    """
    session_factory = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        autocommit=False,
        autoflush=False,
        expire_on_commit=False,
    )
    async with session_factory() as db_session:
        yield db_session
@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist and return a ``"free"``-plan user with the default e-mail."""
    free_user = _make_user(plan="free")
    async_session.add(free_user)
    await async_session.commit()
    # Reload server-side/default column values onto the instance.
    await async_session.refresh(free_user)
    return free_user
@pytest_asyncio.fixture
async def paid_user(async_session):
    """Persist and return a ``"pro"``-plan user (``paid@example.com``)."""
    pro_user = _make_user(email="paid@example.com", plan="pro")
    async_session.add(pro_user)
    await async_session.commit()
    # Reload server-side/default column values onto the instance.
    await async_session.refresh(pro_user)
    return pro_user
@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Persist and return an active brand ``"TestBrand"`` owned by ``test_user``."""
    brand_id = uuid.uuid4()
    owner_id = _to_uuid(test_user.id)
    free_brand = Brand(
        id=brand_id,
        user_id=owner_id,
        name="TestBrand",
        aliases=["TB"],
        website="https://testbrand.com",
        industry="technology",
        platforms=["wenxin", "kimi"],
        frequency="weekly",
        status="active",
    )
    async_session.add(free_brand)
    await async_session.commit()
    await async_session.refresh(free_brand)
    return free_brand
@pytest_asyncio.fixture
async def paid_brand(async_session, paid_user):
    """Persist and return an active brand ``"PaidBrand"`` owned by ``paid_user``."""
    brand_id = uuid.uuid4()
    owner_id = _to_uuid(paid_user.id)
    pro_brand = Brand(
        id=brand_id,
        user_id=owner_id,
        name="PaidBrand",
        aliases=["PB"],
        website="https://paidbrand.com",
        industry="technology",
        platforms=["wenxin", "kimi"],
        frequency="weekly",
        status="active",
    )
    async_session.add(pro_brand)
    await async_session.commit()
    await async_session.refresh(pro_brand)
    return pro_brand
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an HTTP client for ``app`` acting as the free-plan ``test_user``.

    ``get_db`` is overridden to hand out the shared test session, and
    ``get_current_user`` to return ``test_user`` without real authentication.
    The overrides are cleared in a ``finally`` block. Without it, an error
    while opening or closing the client would leave them installed on the
    global ``app`` for every later test.
    """

    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        app.dependency_overrides.clear()
@pytest_asyncio.fixture
async def paid_client(async_session, paid_user):
    """Yield an HTTP client for ``app`` acting as the ``"pro"``-plan ``paid_user``.

    ``get_db`` is overridden to hand out the shared test session, and
    ``get_current_user`` to return ``paid_user`` without real authentication.
    The overrides are cleared in a ``finally`` block. Without it, an error
    while opening or closing the client would leave them installed on the
    global ``app`` for every later test.
    """

    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return paid_user

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        app.dependency_overrides.clear()
class TestOnboardingStatusContract:
    """Contract tests for ``GET /api/v1/onboarding/status``."""

    _URL = "/api/v1/onboarding/status"

    @pytest.mark.asyncio
    async def test_status_incomplete_no_brand(self, async_client):
        """With no brand, onboarding is incomplete and sits at step 1."""
        resp = await async_client.get(self._URL)
        assert resp.status_code == 200
        body = resp.json()
        assert body["completed"] is False
        assert body["brand_id"] is None
        assert body["current_step"] == 1

    @pytest.mark.asyncio
    async def test_status_complete_with_brand(self, async_client, test_brand):
        """Once a brand exists, onboarding is complete and reports its id."""
        resp = await async_client.get(self._URL)
        assert resp.status_code == 200
        body = resp.json()
        assert body["completed"] is True
        assert body["brand_id"] == str(test_brand.id)
class TestOnboardingHealthReportContract:
    """Contract tests for ``GET /api/v1/onboarding/health-report/{brand_id}``."""

    # Diagnosis input with only a few signals set.
    _BASIC_INPUT = {
        "has_direct_answer": True,
        "has_brand_definition": True,
        "answer_ownership_rate": 0.3,
    }
    # Diagnosis input that also populates author/organization/depth signals.
    _RICH_INPUT = {
        "has_direct_answer": True,
        "has_brand_definition": True,
        "has_author_bio": True,
        "author_credentials_complete": 0.8,
        "has_organization": True,
        "content_depth_score": 0.7,
        "answer_ownership_rate": 0.3,
    }

    @staticmethod
    async def _get_report(client, brand_id, **input_fields):
        """GET the health report with data collection stubbed out.

        ``DataCollectorService`` is patched so that awaiting ``collect()``
        returns a ``DataCollectionResult`` wrapping
        ``GEODiagnosisInput(**input_fields)``. ``GEODiagnosisService`` is
        patched to return a real instance, so the actual scoring logic runs
        on that input.

        Returns:
            The response returned by ``client.get``.
        """
        from app.services.diagnosis.data_collector import DataCollectionResult
        from app.services.diagnosis.geo_diagnosis import (
            GEODiagnosisInput,
            GEODiagnosisService,
        )

        input_data = GEODiagnosisInput(**input_fields)
        with patch("app.api.onboarding.DataCollectorService") as mock_collector_cls:
            mock_collector = AsyncMock()
            mock_collector_cls.return_value = mock_collector
            mock_collector.collect.return_value = DataCollectionResult(
                diagnosis_input=input_data,
            )
            with patch(
                "app.api.onboarding.GEODiagnosisService",
                return_value=GEODiagnosisService(),
            ):
                return await client.get(
                    f"/api/v1/onboarding/health-report/{brand_id}"
                )

    @pytest.mark.asyncio
    async def test_health_report_returns_diagnosis_data(self, async_client, test_brand):
        """The report echoes the brand and exposes every contract field."""
        response = await self._get_report(
            async_client, test_brand.id, **self._BASIC_INPUT
        )
        assert response.status_code == 200
        data = response.json()
        assert data["brand_id"] == str(test_brand.id)
        assert data["brand_name"] == "TestBrand"
        for key in (
            "overall_score",
            "health_level",
            "health_level_label",
            "dimensions",
            "recommendations",
            "is_full_report",
        ):
            assert key in data, f"missing key: {key}"
        assert isinstance(data["dimensions"], list)
        assert isinstance(data["recommendations"], list)

    @pytest.mark.asyncio
    async def test_health_report_brand_not_found(self, async_client):
        """An unknown brand id yields 404."""
        response = await async_client.get(
            f"/api/v1/onboarding/health-report/{uuid.uuid4()}"
        )
        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_health_report_free_user_3_dimensions(
        self, async_client, test_brand
    ):
        """Free-plan users get a partial report with exactly three named dimensions."""
        response = await self._get_report(
            async_client, test_brand.id, **self._RICH_INPUT
        )
        # Check status first so an error response fails clearly, not via KeyError.
        assert response.status_code == 200
        data = response.json()
        assert data["is_full_report"] is False
        # Check the list length too: the set alone would hide duplicated entries.
        assert len(data["dimensions"]) == 3
        dim_names = {d["name"] for d in data["dimensions"]}
        assert dim_names == {"内容可提取性", "E-E-A-T信号", "引用就绪度"}

    @pytest.mark.asyncio
    async def test_health_report_paid_user_6_dimensions(
        self, paid_client, paid_brand
    ):
        """Paid-plan users get the full report with six dimensions."""
        response = await self._get_report(
            paid_client, paid_brand.id, **self._RICH_INPUT
        )
        assert response.status_code == 200
        data = response.json()
        assert data["is_full_report"] is True
        assert len(data["dimensions"]) == 6

    @pytest.mark.asyncio
    async def test_health_report_collection_failure_fallback(
        self, async_client, test_brand
    ):
        """If data collection raises, the endpoint still returns a zero-score report."""
        with patch("app.api.onboarding.DataCollectorService") as mock_collector_cls:
            mock_collector = AsyncMock()
            mock_collector_cls.return_value = mock_collector
            mock_collector.collect.side_effect = Exception("DB error")
            response = await async_client.get(
                f"/api/v1/onboarding/health-report/{test_brand.id}"
            )
        assert response.status_code == 200
        data = response.json()
        assert data["overall_score"] == 0
        assert data["health_level"] == "danger"
        assert len(data["dimensions"]) == 3
class TestOnboardingActionSuggestionsContract:
    """Contract tests for ``GET /api/v1/onboarding/action-suggestions/{brand_id}``."""

    # Diagnosis input with only a few signals set.
    _BASIC_INPUT = {
        "has_direct_answer": True,
        "has_brand_definition": True,
        "answer_ownership_rate": 0.3,
    }

    @staticmethod
    async def _get_suggestions(client, brand_id, **input_fields):
        """GET action suggestions with data collection stubbed out.

        ``DataCollectorService`` is patched so that awaiting ``collect()``
        returns a ``DataCollectionResult`` wrapping
        ``GEODiagnosisInput(**input_fields)``. ``GEODiagnosisService`` is
        patched to return a real instance, so the actual scoring logic runs
        on that input.

        Returns:
            The response returned by ``client.get``.
        """
        from app.services.diagnosis.data_collector import DataCollectionResult
        from app.services.diagnosis.geo_diagnosis import (
            GEODiagnosisInput,
            GEODiagnosisService,
        )

        input_data = GEODiagnosisInput(**input_fields)
        with patch("app.api.onboarding.DataCollectorService") as mock_collector_cls:
            mock_collector = AsyncMock()
            mock_collector_cls.return_value = mock_collector
            mock_collector.collect.return_value = DataCollectionResult(
                diagnosis_input=input_data,
            )
            with patch(
                "app.api.onboarding.GEODiagnosisService",
                return_value=GEODiagnosisService(),
            ):
                return await client.get(
                    f"/api/v1/onboarding/action-suggestions/{brand_id}"
                )

    @pytest.mark.asyncio
    async def test_suggestions_have_paid_action_field(
        self, async_client, test_brand
    ):
        """Every suggestion carries a boolean ``is_paid_action`` and button text."""
        response = await self._get_suggestions(
            async_client, test_brand.id, **self._BASIC_INPUT
        )
        assert response.status_code == 200
        data = response.json()
        assert "suggestions" in data
        assert isinstance(data["suggestions"], list)
        for s in data["suggestions"]:
            assert "is_paid_action" in s
            assert "action_button_text" in s
            assert isinstance(s["is_paid_action"], bool)

    @pytest.mark.asyncio
    async def test_free_user_gets_upgrade_suggestion(
        self, async_client, test_brand
    ):
        """Free-plan users get at least one upgrade suggestion, and at least one is paid."""
        response = await self._get_suggestions(
            async_client, test_brand.id, **self._BASIC_INPUT
        )
        # Check status first so an error response fails clearly, not via KeyError.
        assert response.status_code == 200
        data = response.json()
        upgrade_suggestions = [
            s for s in data["suggestions"] if s["action_type"] == "upgrade"
        ]
        assert len(upgrade_suggestions) >= 1
        assert any(s["is_paid_action"] for s in upgrade_suggestions)

    @pytest.mark.asyncio
    async def test_zero_score_suggestions_include_upgrade(
        self, async_client, test_brand
    ):
        """When data collection raises, the fallback suggestions still include an upgrade."""
        with patch("app.api.onboarding.DataCollectorService") as mock_collector_cls:
            mock_collector = AsyncMock()
            mock_collector_cls.return_value = mock_collector
            mock_collector.collect.side_effect = Exception("No data")
            response = await async_client.get(
                f"/api/v1/onboarding/action-suggestions/{test_brand.id}"
            )
        assert response.status_code == 200
        data = response.json()
        upgrade_suggestions = [
            s for s in data["suggestions"] if s["action_type"] == "upgrade"
        ]
        assert len(upgrade_suggestions) >= 1
class TestOnboardingCreateBrandContract:
    """Contract tests for ``POST /api/v1/onboarding/brand``."""

    _URL = "/api/v1/onboarding/brand"

    @pytest.mark.asyncio
    async def test_create_brand_success(self, async_client):
        """A valid name/industry payload returns 201 and echoes the brand name."""
        payload = {"name": "NewBrand", "industry": "tech"}
        resp = await async_client.post(self._URL, json=payload)
        assert resp.status_code == 201
        assert resp.json()["name"] == "NewBrand"

    @pytest.mark.asyncio
    async def test_create_brand_short_name_rejected(self, async_client):
        """A one-character brand name is rejected with 422."""
        resp = await async_client.post(self._URL, json={"name": "A"})
        assert resp.status_code == 422
class TestOnboardingCompleteContract:
    """Contract tests for ``POST /api/v1/onboarding/complete/{brand_id}``."""

    @staticmethod
    def _url(brand_id):
        """Build the completion endpoint path for *brand_id*."""
        return f"/api/v1/onboarding/complete/{brand_id}"

    @pytest.mark.asyncio
    async def test_complete_onboarding(self, async_client, test_brand):
        """Completing onboarding for an owned brand returns ``success: true``."""
        resp = await async_client.post(self._url(test_brand.id))
        assert resp.status_code == 200
        assert resp.json()["success"] is True

    @pytest.mark.asyncio
    async def test_complete_onboarding_brand_not_found(self, async_client):
        """Completing onboarding for an unknown brand id yields 404."""
        resp = await async_client.post(self._url(uuid.uuid4()))
        assert resp.status_code == 404