454 lines
15 KiB
Python
454 lines
15 KiB
Python
import uuid
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from httpx import ASGITransport, AsyncClient
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
from sqlalchemy.pool import StaticPool
|
|
|
|
from app.api.deps import get_current_user, get_db
|
|
from app.database import Base
|
|
from app.main import app
|
|
from app.models.brand import Brand
|
|
from app.models.user import User
|
|
from app.services.auth import hash_password
|
|
|
|
|
|
def _make_user(
    user_id: str | None = None,
    email: str = "test@example.com",
    plan: str = "free",
) -> User:
    """Build an unsaved, active, email-verified ``User`` for tests.

    Args:
        user_id: Explicit id; when falsy a fresh ``uuid4`` string is used.
        email: Login e-mail for the user.
        plan: Plan name stored on ``user.plan``.

    Returns:
        The new ``User``. Its password is ``hash_password("Test@123456")``,
        and ``max_queries`` is 5 on the ``"free"`` plan and 50 otherwise.
    """
    is_free_plan = plan == "free"
    new_user = User(
        id=user_id if user_id else str(uuid.uuid4()),
        email=email,
        password=hash_password("Test@123456"),
        firstName="Test",
        lastName="User",
        isActive=True,
        emailVerified=True,
    )
    new_user.plan = plan
    new_user.max_queries = 5 if is_free_plan else 50
    return new_user
|
|
|
|
|
|
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
|
|
if isinstance(value, uuid.UUID):
|
|
return value
|
|
return uuid.UUID(str(value))
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_engine():
    """Yield an in-memory SQLite async engine with every ``Base`` table created.

    ``StaticPool`` keeps one shared connection, so all sessions see the same
    in-memory database. The engine is disposed after the test finishes.
    """
    in_memory_engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )

    async with in_memory_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield in_memory_engine

    await in_memory_engine.dispose()
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield a single ``AsyncSession`` bound to the test engine.

    ``expire_on_commit=False`` keeps attributes loaded after commit. Fixtures can
    therefore return committed ORM objects whose attributes stay readable.
    """
    session_factory = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        autocommit=False,
        autoflush=False,
        expire_on_commit=False,
    )

    async with session_factory() as db_session:
        yield db_session
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist and return a free-plan user with the default test e-mail."""
    free_user = _make_user(plan="free")

    async_session.add(free_user)
    await async_session.commit()
    # Reload column values from the database after the commit.
    await async_session.refresh(free_user)

    return free_user
|
|
|
|
|
|
@pytest_asyncio.fixture
async def paid_user(async_session):
    """Persist and return a ``"pro"``-plan user (``paid@example.com``)."""
    pro_user = _make_user(email="paid@example.com", plan="pro")

    async_session.add(pro_user)
    await async_session.commit()
    # Reload column values from the database after the commit.
    await async_session.refresh(pro_user)

    return pro_user
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Persist and return an active ``"TestBrand"`` owned by ``test_user``.

    The brand is tracked weekly on the ``wenxin`` and ``kimi`` platforms.
    """
    owner_id = _to_uuid(test_user.id)
    free_brand = Brand(
        id=uuid.uuid4(),
        user_id=owner_id,
        name="TestBrand",
        aliases=["TB"],
        website="https://testbrand.com",
        industry="technology",
        platforms=["wenxin", "kimi"],
        frequency="weekly",
        status="active",
    )

    async_session.add(free_brand)
    await async_session.commit()
    await async_session.refresh(free_brand)

    return free_brand
|
|
|
|
|
|
@pytest_asyncio.fixture
async def paid_brand(async_session, paid_user):
    """Persist and return an active ``"PaidBrand"`` owned by ``paid_user``.

    The brand is tracked weekly on the ``wenxin`` and ``kimi`` platforms.
    """
    owner_id = _to_uuid(paid_user.id)
    pro_brand = Brand(
        id=uuid.uuid4(),
        user_id=owner_id,
        name="PaidBrand",
        aliases=["PB"],
        website="https://paidbrand.com",
        industry="technology",
        platforms=["wenxin", "kimi"],
        frequency="weekly",
        status="active",
    )

    async_session.add(pro_brand)
    await async_session.commit()
    await async_session.refresh(pro_brand)

    return pro_brand
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an ``AsyncClient`` that talks to ``app`` as the free-plan ``test_user``.

    ``get_db`` is overridden to yield the fixture session. ``get_current_user``
    is overridden to return ``test_user`` directly.

    The overrides are cleared in a ``finally``. If building the transport or
    client fails before ``yield``, pytest never resumes this generator, and
    without ``finally`` the overrides would leak onto the global ``app``.
    """

    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        app.dependency_overrides.clear()
|
|
|
|
|
|
@pytest_asyncio.fixture
async def paid_client(async_session, paid_user):
    """Yield an ``AsyncClient`` that talks to ``app`` as the pro-plan ``paid_user``.

    ``get_db`` is overridden to yield the fixture session. ``get_current_user``
    is overridden to return ``paid_user`` directly.

    The overrides are cleared in a ``finally``. If building the transport or
    client fails before ``yield``, pytest never resumes this generator, and
    without ``finally`` the overrides would leak onto the global ``app``.
    """

    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return paid_user

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        app.dependency_overrides.clear()
|
|
|
|
|
|
class TestOnboardingStatusContract:
    """Contract tests for ``GET /api/v1/onboarding/status``."""

    @pytest.mark.asyncio
    async def test_status_incomplete_no_brand(self, async_client):
        """Without a brand, onboarding is incomplete and sits at step 1."""
        resp = await async_client.get("/api/v1/onboarding/status")
        assert resp.status_code == 200

        payload = resp.json()
        assert payload["completed"] is False
        assert payload["brand_id"] is None
        assert payload["current_step"] == 1

    @pytest.mark.asyncio
    async def test_status_complete_with_brand(self, async_client, test_brand):
        """Once the user owns a brand, onboarding is complete and echoes its id."""
        resp = await async_client.get("/api/v1/onboarding/status")
        assert resp.status_code == 200

        payload = resp.json()
        assert payload["completed"] is True
        assert payload["brand_id"] == str(test_brand.id)
|
|
|
|
|
# GEODiagnosisInput fields shared by several tests below.
_BASIC_INPUT = {
    "has_direct_answer": True,
    "has_brand_definition": True,
    "answer_ownership_rate": 0.3,
}
_RICH_INPUT = {
    "has_direct_answer": True,
    "has_brand_definition": True,
    "has_author_bio": True,
    "author_credentials_complete": 0.8,
    "has_organization": True,
    "content_depth_score": 0.7,
    "answer_ownership_rate": 0.3,
}


async def _get_with_diagnosis(client, url, **input_fields):
    """GET *url* with data collection and GEO diagnosis patched in ``app.api.onboarding``.

    ``DataCollectorService`` is replaced by a mock whose instance is an
    ``AsyncMock``. Its ``collect`` returns a ``DataCollectionResult`` wrapping
    ``GEODiagnosisInput(**input_fields)``. ``GEODiagnosisService`` is patched to
    return a real ``GEODiagnosisService()`` instance, so the endpoint presumably
    runs the real diagnosis on that fixed input.

    Args:
        client: ``AsyncClient`` used to issue the request.
        url: Request path.
        **input_fields: Keyword arguments for ``GEODiagnosisInput``.

    Returns:
        The response returned by ``client.get(url)``.
    """
    # Kept as in-function imports, as in the original tests.
    from app.services.diagnosis.data_collector import DataCollectionResult
    from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput, GEODiagnosisService

    input_data = GEODiagnosisInput(**input_fields)
    with patch("app.api.onboarding.DataCollectorService") as mock_collector_cls:
        mock_collector = AsyncMock()
        mock_collector_cls.return_value = mock_collector
        mock_collector.collect.return_value = DataCollectionResult(
            diagnosis_input=input_data,
        )
        with patch(
            "app.api.onboarding.GEODiagnosisService",
            return_value=GEODiagnosisService(),
        ):
            return await client.get(url)


class TestOnboardingHealthReportContract:
    """Contract tests for ``GET /api/v1/onboarding/health-report/{brand_id}``."""

    @pytest.mark.asyncio
    async def test_health_report_returns_diagnosis_data(self, async_client, test_brand):
        response = await _get_with_diagnosis(
            async_client,
            f"/api/v1/onboarding/health-report/{test_brand.id}",
            **_BASIC_INPUT,
        )

        assert response.status_code == 200
        data = response.json()
        assert data["brand_id"] == str(test_brand.id)
        assert data["brand_name"] == "TestBrand"
        assert "overall_score" in data
        assert "health_level" in data
        assert "health_level_label" in data
        assert "dimensions" in data
        assert "recommendations" in data
        assert "is_full_report" in data
        assert isinstance(data["dimensions"], list)
        assert isinstance(data["recommendations"], list)

    @pytest.mark.asyncio
    async def test_health_report_brand_not_found(self, async_client):
        response = await async_client.get(
            f"/api/v1/onboarding/health-report/{uuid.uuid4()}"
        )
        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_health_report_free_user_3_dimensions(
        self, async_client, test_brand
    ):
        response = await _get_with_diagnosis(
            async_client,
            f"/api/v1/onboarding/health-report/{test_brand.id}",
            **_RICH_INPUT,
        )

        # Check status first so a failure reports the HTTP error, not a KeyError.
        assert response.status_code == 200
        data = response.json()
        assert data["is_full_report"] is False
        dim_names = {d["name"] for d in data["dimensions"]}
        assert len(dim_names) == 3
        assert "内容可提取性" in dim_names
        assert "E-E-A-T信号" in dim_names
        assert "引用就绪度" in dim_names

    @pytest.mark.asyncio
    async def test_health_report_paid_user_6_dimensions(
        self, paid_client, paid_brand
    ):
        response = await _get_with_diagnosis(
            paid_client,
            f"/api/v1/onboarding/health-report/{paid_brand.id}",
            **_RICH_INPUT,
        )

        assert response.status_code == 200
        data = response.json()
        assert data["is_full_report"] is True
        assert len(data["dimensions"]) == 6

    @pytest.mark.asyncio
    async def test_health_report_collection_failure_fallback(
        self, async_client, test_brand
    ):
        with patch("app.api.onboarding.DataCollectorService") as mock_collector_cls:
            mock_collector = AsyncMock()
            mock_collector_cls.return_value = mock_collector
            mock_collector.collect.side_effect = Exception("DB error")

            response = await async_client.get(
                f"/api/v1/onboarding/health-report/{test_brand.id}"
            )

        assert response.status_code == 200
        data = response.json()
        assert data["overall_score"] == 0
        assert data["health_level"] == "danger"
        assert len(data["dimensions"]) == 3


class TestOnboardingActionSuggestionsContract:
    """Contract tests for ``GET /api/v1/onboarding/action-suggestions/{brand_id}``."""

    @pytest.mark.asyncio
    async def test_suggestions_have_paid_action_field(
        self, async_client, test_brand
    ):
        response = await _get_with_diagnosis(
            async_client,
            f"/api/v1/onboarding/action-suggestions/{test_brand.id}",
            **_BASIC_INPUT,
        )

        assert response.status_code == 200
        data = response.json()
        assert "suggestions" in data
        assert isinstance(data["suggestions"], list)
        for s in data["suggestions"]:
            assert "is_paid_action" in s
            assert "action_button_text" in s
            assert isinstance(s["is_paid_action"], bool)

    @pytest.mark.asyncio
    async def test_free_user_gets_upgrade_suggestion(
        self, async_client, test_brand
    ):
        response = await _get_with_diagnosis(
            async_client,
            f"/api/v1/onboarding/action-suggestions/{test_brand.id}",
            **_BASIC_INPUT,
        )

        assert response.status_code == 200
        data = response.json()
        upgrade_suggestions = [
            s for s in data["suggestions"] if s["action_type"] == "upgrade"
        ]
        assert len(upgrade_suggestions) >= 1
        assert any(s["is_paid_action"] for s in upgrade_suggestions)

    @pytest.mark.asyncio
    async def test_zero_score_suggestions_include_upgrade(
        self, async_client, test_brand
    ):
        with patch("app.api.onboarding.DataCollectorService") as mock_collector_cls:
            mock_collector = AsyncMock()
            mock_collector_cls.return_value = mock_collector
            mock_collector.collect.side_effect = Exception("No data")

            response = await async_client.get(
                f"/api/v1/onboarding/action-suggestions/{test_brand.id}"
            )

        assert response.status_code == 200
        data = response.json()
        upgrade_suggestions = [
            s for s in data["suggestions"] if s["action_type"] == "upgrade"
        ]
        assert len(upgrade_suggestions) >= 1
|
|
|
|
|
|
class TestOnboardingCreateBrandContract:
    """Contract tests for ``POST /api/v1/onboarding/brand``."""

    @pytest.mark.asyncio
    async def test_create_brand_success(self, async_client):
        """A name plus industry creates the brand (201) and echoes the name."""
        body = {"name": "NewBrand", "industry": "tech"}
        resp = await async_client.post("/api/v1/onboarding/brand", json=body)

        assert resp.status_code == 201
        assert resp.json()["name"] == "NewBrand"

    @pytest.mark.asyncio
    async def test_create_brand_short_name_rejected(self, async_client):
        """A one-character name fails request validation (422)."""
        body = {"name": "A"}
        resp = await async_client.post("/api/v1/onboarding/brand", json=body)

        assert resp.status_code == 422
|
|
|
|
|
|
class TestOnboardingCompleteContract:
    """Contract tests for ``POST /api/v1/onboarding/complete/{brand_id}``."""

    @pytest.mark.asyncio
    async def test_complete_onboarding(self, async_client, test_brand):
        """Completing onboarding for an owned brand reports success."""
        endpoint = f"/api/v1/onboarding/complete/{test_brand.id}"
        resp = await async_client.post(endpoint)

        assert resp.status_code == 200
        assert resp.json()["success"] is True

    @pytest.mark.asyncio
    async def test_complete_onboarding_brand_not_found(self, async_client):
        """An unknown brand id yields 404."""
        endpoint = f"/api/v1/onboarding/complete/{uuid.uuid4()}"
        resp = await async_client.post(endpoint)

        assert resp.status_code == 404
|