# geo/backend/tests/test_integration/test_monetization_flow.py
#
# 192 lines
# 7.4 KiB
# Python

import uuid
from unittest.mock import AsyncMock, patch
import pytest
from tests.fixtures.auth import _make_user, _to_uuid
class TestMonetizationFlow:
    """Integration test for the monetization journey of a single user.

    The flow starts with a user on the ``free`` plan, walks through brand
    onboarding, a limited health report, a ``pro`` order paid through a
    simulated WeChat callback, the full health report, and finally the
    attribution / ROI endpoints.
    """

    @staticmethod
    async def _fetch_health_report(async_client, brand_id):
        """GET the onboarding health report with its collaborators patched.

        ``DataCollectorService`` and ``GEODiagnosisService`` are patched where
        ``app.api.onboarding`` looks them up, so the endpoint receives a mocked
        collector and a real diagnosis service instance built here.

        Args:
            async_client: HTTP test client fixture used to call the API.
            brand_id: Identifier of the brand, as returned by the brand
                creation endpoint.

        Returns:
            A ``(response, result)`` tuple: the HTTP response of the health
            report request, and the value returned by ``service.diagnose`` for
            the fixed diagnosis input used by this test.
        """
        from app.services.diagnosis.data_collector import DataCollectionResult
        from app.services.diagnosis.geo_diagnosis import (
            GEODiagnosisInput,
            GEODiagnosisService,
        )

        diagnosis_input = GEODiagnosisInput(
            has_direct_answer=True,
            has_brand_definition=True,
            has_author_bio=True,
            author_credentials_complete=0.8,
            has_organization=True,
            content_depth_score=0.7,
            answer_ownership_rate=0.3,
        )
        service = GEODiagnosisService()
        result = service.diagnose(diagnosis_input)

        with patch("app.api.onboarding.DataCollectorService") as collector_cls:
            collector = AsyncMock()
            collector_cls.return_value = collector
            collector.collect.return_value = DataCollectionResult(
                diagnosis_input=diagnosis_input,
            )
            with patch(
                "app.api.onboarding.GEODiagnosisService", return_value=service
            ):
                response = await async_client.get(
                    f"/api/v1/onboarding/health-report/{brand_id}"
                )

        return response, result

    @pytest.mark.asyncio
    async def test_full_monetization_flow(self, async_client, async_session):
        """Walk a free user through purchase and check each stage's response."""
        user = _make_user(email="monetization@example.com", plan="free")
        async_session.add(user)
        await async_session.commit()
        await async_session.refresh(user)

        from app.api.deps import get_current_user, get_db
        from app.main import app

        async def override_get_db():
            yield async_session

        async def override_get_current_user():
            return user

        # Route the API through the test session and the user created above.
        app.dependency_overrides[get_db] = override_get_db
        app.dependency_overrides[get_current_user] = override_get_current_user

        try:
            # Step 1: create a brand through onboarding.
            brand_resp = await async_client.post(
                "/api/v1/onboarding/brand",
                json={"name": "MonoBrand", "industry": "technology"},
            )
            assert brand_resp.status_code == 201
            brand_data = brand_resp.json()
            brand_id = brand_data["id"]
            assert brand_data["name"] == "MonoBrand"

            # Step 2: mark onboarding as complete for that brand.
            onboarding_resp = await async_client.post(
                f"/api/v1/onboarding/complete/{brand_id}"
            )
            assert onboarding_resp.status_code == 200
            assert onboarding_resp.json()["success"] is True

            # Step 3: health report while still on the free plan.
            health_resp, _ = await self._fetch_health_report(async_client, brand_id)
            assert health_resp.status_code == 200
            health_data = health_resp.json()
            assert health_data["brand_id"] == brand_id
            assert "overall_score" in health_data
            assert "dimensions" in health_data
            assert health_data["is_full_report"] is False

            # Step 4: create a payment order for the pro plan.
            order_resp = await async_client.post(
                "/api/v1/payments/orders",
                json={"plan": "pro", "payment_provider": "wechat"},
            )
            assert order_resp.status_code == 201
            order_data = order_resp.json()
            order_id = order_data["order_id"]
            assert "pay_url" in order_data
            assert order_data["amount"] == 599
            assert order_data["status"] == "pending"

            # Step 5: simulate the provider callback by invoking the callback
            # processor directly on the stored order.
            from app.models.payment_order import PaymentOrder as PaymentOrderModel
            from app.services.payment.base import PaymentCallback
            from sqlalchemy import select

            stmt = select(PaymentOrderModel).where(
                PaymentOrderModel.id == uuid.UUID(order_id)
            )
            db_result = await async_session.execute(stmt)
            order = db_result.scalar_one()

            callback = PaymentCallback(
                order_id=str(order.id),
                payment_id=f"wx_pay_{order.id}",
                amount=599,
                status="success",
                raw_data={"out_trade_no": str(order.id), "result_code": "SUCCESS"},
            )

            from app.api.payments import _process_callback

            await _process_callback(async_session, callback, "wechat")
            await async_session.commit()

            await async_session.refresh(order)
            assert order.status == "paid"
            assert order.payment_id is not None

            await async_session.refresh(user)
            assert user.plan == "pro"
            assert user.max_queries == 50

            # Step 6: the same health report is now expected to be the full one.
            paid_health_resp, result = await self._fetch_health_report(
                async_client, brand_id
            )
            assert paid_health_resp.status_code == 200
            paid_health_data = paid_health_resp.json()
            assert paid_health_data["is_full_report"] is True
            assert len(paid_health_data["dimensions"]) == 6

            # Step 7: store a completed diagnosis, then start attribution
            # tracking; the baseline is compared against that diagnosis score.
            from app.models.diagnosis_record import DiagnosisRecord

            diag_record = DiagnosisRecord(
                brand_id=uuid.UUID(brand_id),
                user_id=_to_uuid(user.id),
                diagnosis_type="geo",
                status="completed",
                overall_score=result.overall_score,
                result_json=result.to_dict(),
            )
            async_session.add(diag_record)
            await async_session.commit()
            await async_session.refresh(diag_record)

            attr_resp = await async_client.post(
                "/api/v1/attribution/start",
                json={"brand_id": brand_id},
            )
            assert attr_resp.status_code == 200
            attr_data = attr_resp.json()
            assert "id" in attr_data
            assert attr_data["brand_id"] == brand_id
            assert attr_data["baseline_score"] == result.overall_score
            assert attr_data["status"] == "tracking"

            # Step 8: ROI summary for the brand.
            roi_resp = await async_client.get(
                f"/api/v1/attribution/roi/{brand_id}"
            )
            assert roi_resp.status_code == 200
            roi_data = roi_resp.json()
            assert "roi_percentage" in roi_data
            assert "value_generated" in roi_data
            assert "subscription_cost" in roi_data
            assert roi_data["brand_name"] == "MonoBrand"
            assert roi_data["current_plan"] == "pro"
        finally:
            # NOTE: this resets every dependency override on the app, not only
            # the two installed by this test.
            app.dependency_overrides.clear()