192 lines
7.4 KiB
Python
192 lines
7.4 KiB
Python
import uuid
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
|
|
from tests.fixtures.auth import _make_user, _to_uuid
|
|
|
|
|
|
class TestMonetizationFlow:
    """End-to-end test of the free -> pro upgrade path.

    Drives the onboarding, payment, health-report and attribution endpoints
    through ``async_client`` while the app's ``get_db`` and
    ``get_current_user`` dependencies are overridden to use the test session
    and a freshly created user.
    """

    # Amount asserted on the created "pro" order and echoed back in the
    # simulated payment callback.
    _PRO_PLAN_AMOUNT = 599

    @staticmethod
    async def _get_health_report(async_client, brand_id):
        """GET the health report for ``brand_id`` with diagnosis services patched.

        ``DataCollectorService`` and ``GEODiagnosisService`` are patched in the
        ``app.api.onboarding`` namespace: the collector becomes an ``AsyncMock``
        whose ``collect`` returns a fixed ``DataCollectionResult``, and the
        diagnosis service class returns a real ``GEODiagnosisService`` instance
        created here.

        Args:
            async_client: HTTP client fixture used to issue the request.
            brand_id: Brand identifier (string form) placed in the URL.

        Returns:
            A ``(response, diagnosis)`` tuple: the HTTP response, and the result
            of calling ``diagnose`` directly on the same fixed input.
        """
        # Imported lazily, as in the rest of this test, so the app modules are
        # only loaded when the test actually runs.
        from app.services.diagnosis.data_collector import DataCollectionResult
        from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput, GEODiagnosisService

        with patch("app.api.onboarding.DataCollectorService") as mock_collector_cls:
            mock_collector = AsyncMock()
            mock_collector_cls.return_value = mock_collector

            input_data = GEODiagnosisInput(
                has_direct_answer=True,
                has_brand_definition=True,
                has_author_bio=True,
                author_credentials_complete=0.8,
                has_organization=True,
                content_depth_score=0.7,
                answer_ownership_rate=0.3,
            )
            mock_collector.collect.return_value = DataCollectionResult(
                diagnosis_input=input_data,
            )

            service = GEODiagnosisService()
            diagnosis = service.diagnose(input_data)

            with patch("app.api.onboarding.GEODiagnosisService", return_value=service):
                response = await async_client.get(
                    f"/api/v1/onboarding/health-report/{brand_id}"
                )

        return response, diagnosis

    @pytest.mark.asyncio
    async def test_full_monetization_flow(self, async_client, async_session):
        """Walk one user from free onboarding to a paid plan and ROI tracking."""
        user = _make_user(email="monetization@example.com", plan="free")
        async_session.add(user)
        await async_session.commit()
        await async_session.refresh(user)

        from app.api.deps import get_current_user, get_db
        from app.main import app

        async def override_get_db():
            yield async_session

        async def override_get_current_user():
            return user

        # Snapshot the registry so cleanup restores exactly what was there
        # before this test, instead of wiping overrides owned by other code.
        previous_overrides = dict(app.dependency_overrides)
        app.dependency_overrides[get_db] = override_get_db
        app.dependency_overrides[get_current_user] = override_get_current_user

        try:
            # --- 1. Onboarding: create a brand and complete onboarding -------
            brand_resp = await async_client.post(
                "/api/v1/onboarding/brand",
                json={"name": "MonoBrand", "industry": "technology"},
            )
            assert brand_resp.status_code == 201
            brand_data = brand_resp.json()
            brand_id = brand_data["id"]
            assert brand_data["name"] == "MonoBrand"

            onboarding_resp = await async_client.post(
                f"/api/v1/onboarding/complete/{brand_id}"
            )
            assert onboarding_resp.status_code == 200
            assert onboarding_resp.json()["success"] is True

            # --- 2. Health report while still on the free plan ---------------
            health_resp, _ = await self._get_health_report(async_client, brand_id)

            assert health_resp.status_code == 200
            health_data = health_resp.json()
            assert health_data["brand_id"] == brand_id
            assert "overall_score" in health_data
            assert "dimensions" in health_data
            assert health_data["is_full_report"] is False

            # --- 3. Create a payment order for the pro plan ------------------
            order_resp = await async_client.post(
                "/api/v1/payments/orders",
                json={"plan": "pro", "payment_provider": "wechat"},
            )
            assert order_resp.status_code == 201
            order_data = order_resp.json()
            order_id = order_data["order_id"]
            assert "pay_url" in order_data
            assert order_data["amount"] == self._PRO_PLAN_AMOUNT
            assert order_data["status"] == "pending"

            # --- 4. Simulate the provider's success callback -----------------
            from app.models.payment_order import PaymentOrder as PaymentOrderModel
            from app.services.payment.base import PaymentCallback

            from sqlalchemy import select

            stmt = select(PaymentOrderModel).where(
                PaymentOrderModel.id == uuid.UUID(order_id)
            )
            db_result = await async_session.execute(stmt)
            order = db_result.scalar_one()

            callback = PaymentCallback(
                order_id=str(order.id),
                payment_id=f"wx_pay_{order.id}",
                amount=self._PRO_PLAN_AMOUNT,
                status="success",
                raw_data={"out_trade_no": str(order.id), "result_code": "SUCCESS"},
            )

            # The callback handler is invoked directly rather than through an
            # HTTP route.
            from app.api.payments import _process_callback

            await _process_callback(async_session, callback, "wechat")
            await async_session.commit()

            await async_session.refresh(order)
            assert order.status == "paid"
            assert order.payment_id is not None

            await async_session.refresh(user)
            assert user.plan == "pro"
            assert user.max_queries == 50

            # --- 5. Health report again, now as a paying user ----------------
            paid_health_resp, result = await self._get_health_report(
                async_client, brand_id
            )

            assert paid_health_resp.status_code == 200
            paid_health_data = paid_health_resp.json()
            assert paid_health_data["is_full_report"] is True
            assert len(paid_health_data["dimensions"]) == 6

            # --- 6. Attribution tracking and ROI -----------------------------
            from app.models.diagnosis_record import DiagnosisRecord

            # Persist a completed diagnosis; the attribution response is
            # asserted below to report this record's score as its baseline.
            diag_record = DiagnosisRecord(
                brand_id=uuid.UUID(brand_id),
                user_id=_to_uuid(user.id),
                diagnosis_type="geo",
                status="completed",
                overall_score=result.overall_score,
                result_json=result.to_dict(),
            )
            async_session.add(diag_record)
            await async_session.commit()
            await async_session.refresh(diag_record)

            attr_resp = await async_client.post(
                "/api/v1/attribution/start",
                json={"brand_id": brand_id},
            )
            assert attr_resp.status_code == 200
            attr_data = attr_resp.json()
            assert "id" in attr_data
            assert attr_data["brand_id"] == brand_id
            assert attr_data["baseline_score"] == result.overall_score
            assert attr_data["status"] == "tracking"

            roi_resp = await async_client.get(
                f"/api/v1/attribution/roi/{brand_id}"
            )
            assert roi_resp.status_code == 200
            roi_data = roi_resp.json()
            assert "roi_percentage" in roi_data
            assert "value_generated" in roi_data
            assert "subscription_cost" in roi_data
            assert roi_data["brand_name"] == "MonoBrand"
            assert roi_data["current_plan"] == "pro"

        finally:
            # Put the override registry back to its pre-test contents.
            app.dependency_overrides.clear()
            app.dependency_overrides.update(previous_overrides)
|