301 lines
9.4 KiB
Python
301 lines
9.4 KiB
Python
import logging
|
||
import uuid
|
||
from datetime import datetime
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException, status
|
||
from pydantic import BaseModel
|
||
from sqlalchemy import select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.api.deps import get_current_user
|
||
from app.database import get_db
|
||
from app.models.attribution_record import AttributionRecord
|
||
from app.models.brand import Brand
|
||
from app.models.diagnosis_record import DiagnosisRecord
|
||
from app.models.user import User
|
||
from app.services.attribution.attribution_engine import AttributionEngine
|
||
from app.services.attribution.roi_calculator import ROICalculator
|
||
|
||
# Module logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Router that the attribution endpoints below register themselves on.
router = APIRouter()
|
||
|
||
# Subscription price per plan name. get_roi_report looks the user's plan up
# here and falls back to 0.0 for names that are not listed.
# NOTE(review): currency and billing period are not stated anywhere here — confirm.
PLAN_COSTS: dict[str, float] = dict(
    free=0.0,
    starter=99.0,
    pro=299.0,
    enterprise=999.0,
)
|
||
|
||
|
||
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
|
||
if isinstance(value, uuid.UUID):
|
||
return value
|
||
return uuid.UUID(str(value))
|
||
|
||
|
||
class StartTrackingRequest(BaseModel):
    """Request body for ``POST /start``.

    Both IDs arrive as strings and are parsed into UUIDs by the handler.
    """

    # Brand to track; must be owned by the calling user.
    brand_id: str
    # Optional content item to attach; an empty or missing value means "none".
    content_id: str | None = None
|
||
|
||
|
||
class AttributionResponse(BaseModel):
    """API representation of a single attribution record.

    UUID columns are exposed as strings. ``from_attributes`` lets the model
    also be validated directly from an ORM object.
    """

    id: str
    brand_id: str
    # Nullable but required: callers must pass it explicitly (may be None).
    content_id: str | None
    baseline_score: float
    current_score: float | None
    score_delta: float | None
    status: str
    roi_percentage: float | None
    created_at: datetime

    model_config = {"from_attributes": True}
|
||
|
||
|
||
class ROIReport(BaseModel):
    """Response body of ``GET /roi/{brand_id}``.

    ``subscription_cost`` is the ``PLAN_COSTS`` price of ``current_plan``;
    ``value_generated``, ``roi_percentage`` and ``break_even_delta`` are taken
    from ``ROICalculator.calculate_roi``.
    """

    brand_id: str
    brand_name: str
    subscription_cost: float
    current_plan: str
    total_score_delta: float
    value_generated: float
    roi_percentage: float
    break_even_delta: float
    tracking_records: list[AttributionResponse]
    # Required but nullable: the handler passes None when there are not two
    # distinct completed diagnoses to compare.
    ab_comparison: dict | None
|
||
|
||
|
||
class ABComparisonResponse(BaseModel):
    """Response body of ``GET /ab-comparison/{brand_id}``.

    Compares a brand's earliest completed diagnosis ("before") with its latest
    one ("after").
    """

    brand_id: str
    brand_name: str
    overall_before: float
    overall_after: float
    overall_delta: float
    # Per-dimension entries as produced by ROICalculator.generate_ab_comparison;
    # NOTE(review): entry schema is defined by that calculator — confirm there.
    dimensions: list[dict]
|
||
|
||
|
||
def _record_to_response(record: AttributionRecord) -> AttributionResponse:
    """Convert an ``AttributionRecord`` row into an ``AttributionResponse``.

    ``id`` and ``brand_id`` are stringified; a falsy ``content_id`` becomes
    ``None``. Scores, status, ROI and ``created_at`` are copied unchanged.
    """
    raw_content = record.content_id
    return AttributionResponse(
        id=str(record.id),
        brand_id=str(record.brand_id),
        content_id=None if not raw_content else str(raw_content),
        baseline_score=record.baseline_score,
        current_score=record.current_score,
        score_delta=record.score_delta,
        status=record.status,
        roi_percentage=record.roi_percentage,
        created_at=record.created_at,
    )
|
||
|
||
|
||
async def _get_brand_or_404(
    brand_id: uuid.UUID,
    current_user: User,
    db: AsyncSession,
) -> Brand:
    """Load a brand that belongs to ``current_user``.

    Args:
        brand_id: Primary key of the brand.
        current_user: Authenticated user; only a brand whose ``user_id`` equals
            this user's id is returned.
        db: Async database session.

    Returns:
        The matching ``Brand``.

    Raises:
        HTTPException: 404 when the brand does not exist or belongs to another user.
    """
    owner_id = _to_uuid(current_user.id)
    query = select(Brand).where(Brand.id == brand_id, Brand.user_id == owner_id)
    found = (await db.execute(query)).scalar_one_or_none()
    if found:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="品牌不存在",
    )
|
||
|
||
|
||
@router.post("/start", response_model=AttributionResponse)
async def start_tracking(
    body: StartTrackingRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Start attribution tracking for one of the current user's brands.

    Parses the request IDs, verifies brand ownership, then creates the record
    through ``AttributionEngine.start_tracking``.

    Raises:
        HTTPException: 422 if ``brand_id`` or a non-empty ``content_id`` is not
            a valid UUID; 404 if the brand does not exist or belongs to another
            user.
    """
    # The body fields are plain strings, so malformed IDs reach us here.
    # _to_uuid raises ValueError for them, which would otherwise surface as
    # an unhandled HTTP 500.
    try:
        brand_id = _to_uuid(body.brand_id)
        content_id = _to_uuid(body.content_id) if body.content_id else None
    except ValueError as exc:
        # 422 matches what FastAPI returns for malformed UUID path params.
        raise HTTPException(
            status_code=422,
            detail="无效的ID格式",
        ) from exc

    brand = await _get_brand_or_404(brand_id, current_user, db)

    engine = AttributionEngine()
    record = await engine.start_tracking(db, brand.id, content_id, current_user.id)
    return _record_to_response(record)
|
||
|
||
|
||
@router.get("/brand/{brand_id}")
async def get_brand_attribution(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the attribution summary of one of the current user's brands.

    The summary dict comes from ``AttributionEngine.get_brand_attribution_summary``;
    its ``"records"`` entry is replaced with ``AttributionResponse`` objects.

    Raises:
        HTTPException: 404 if the brand does not exist or belongs to another user.
    """
    owned_brand = await _get_brand_or_404(brand_id, current_user, db)
    summary = await AttributionEngine().get_brand_attribution_summary(db, owned_brand.id)
    # Swap ORM rows for their API representation before serialising.
    summary["records"] = list(map(_record_to_response, summary["records"]))
    return summary
|
||
|
||
|
||
@router.get("/{record_id}", response_model=AttributionResponse)
async def get_attribution_record(
    record_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return one attribution record whose brand is owned by the current user.

    Raises:
        HTTPException: 404 if the record does not exist or its brand belongs to
            another user. Both cases get the same response so other users'
            record IDs cannot be probed.
    """
    owner_id = _to_uuid(current_user.id)
    # Ownership is enforced in the query itself by joining through Brand;
    # looking the record up by ID alone would expose every user's records.
    stmt = (
        select(AttributionRecord)
        .join(Brand, Brand.id == AttributionRecord.brand_id)
        .where(
            AttributionRecord.id == record_id,
            Brand.user_id == owner_id,
        )
    )
    record = (await db.execute(stmt)).scalar_one_or_none()
    if not record:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="归因记录不存在",
        )
    return _record_to_response(record)
|
||
|
||
|
||
@router.post("/{record_id}/check", response_model=AttributionResponse)
async def check_attribution(
    record_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Re-check an attribution record owned by the current user.

    Ownership is verified first; the record is then passed to
    ``AttributionEngine.check_attribution``.

    Raises:
        HTTPException: 404 if the record does not exist, belongs to another
            user's brand, or the engine raises ``ValueError`` for it.
    """
    owner_id = _to_uuid(current_user.id)
    # Without this check any authenticated user could trigger a check (and
    # thus a write) on another user's record just by knowing its ID.
    owned_id = (
        await db.execute(
            select(AttributionRecord.id)
            .join(Brand, Brand.id == AttributionRecord.brand_id)
            .where(
                AttributionRecord.id == record_id,
                Brand.user_id == owner_id,
            )
        )
    ).scalar_one_or_none()
    if owned_id is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="归因记录不存在",
        )

    engine = AttributionEngine()
    try:
        record = await engine.check_attribution(db, record_id)
    except ValueError as exc:
        # The engine signals a missing record with ValueError.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="归因记录不存在",
        ) from exc
    return _record_to_response(record)
|
||
|
||
|
||
async def _fetch_completed_diagnosis(
    db: AsyncSession,
    brand_id: uuid.UUID,
    *,
    newest: bool,
) -> "DiagnosisRecord | None":
    """Return a brand's earliest completed diagnosis, or its latest if ``newest``.

    Ordering is by ``DiagnosisRecord.completed_at``. Returns ``None`` when the
    brand has no diagnosis with status ``"completed"``.
    """
    # NOTE(review): rows with a NULL completed_at sort per the database's NULL
    # ordering rules — confirm completed rows always carry a timestamp.
    ordering = (
        DiagnosisRecord.completed_at.desc() if newest else DiagnosisRecord.completed_at.asc()
    )
    stmt = (
        select(DiagnosisRecord)
        .where(
            DiagnosisRecord.brand_id == brand_id,
            DiagnosisRecord.status == "completed",
        )
        .order_by(ordering)
        .limit(1)
    )
    return (await db.execute(stmt)).scalar_one_or_none()


def _dimensions_by_name(diagnosis: "DiagnosisRecord") -> dict:
    """Index ``diagnosis.result_json["dimensions"]`` entries by their ``"name"`` key.

    Returns an empty dict when ``result_json`` is empty or None.
    """
    dims = diagnosis.result_json.get("dimensions", []) if diagnosis.result_json else []
    return {d.get("name"): d for d in dims}


@router.get("/roi/{brand_id}", response_model=ROIReport)
async def get_roi_report(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Build the ROI report for one of the current user's brands.

    Combines the brand's attribution summary with the price of the user's plan
    (from ``PLAN_COSTS``; unknown plans count as 0.0). When the brand has two
    distinct completed diagnoses, it also adds an A/B comparison between the
    earliest and the latest one.

    Raises:
        HTTPException: 404 if the brand does not exist or belongs to another user.
    """
    brand = await _get_brand_or_404(brand_id, current_user, db)

    engine = AttributionEngine()
    summary = await engine.get_brand_attribution_summary(db, brand.id)

    # A missing or empty ``plan`` attribute is treated as the free plan.
    user_plan = getattr(current_user, "plan", "free") or "free"
    subscription_cost = PLAN_COSTS.get(user_plan, 0.0)

    calculator = ROICalculator()
    roi_data = calculator.calculate_roi(
        subscription_cost=subscription_cost,
        score_delta=summary["total_score_delta"],
        attribution_records=summary["records"],
    )

    ab_comparison = None
    baseline_record = await _fetch_completed_diagnosis(db, brand.id, newest=False)
    latest_record = await _fetch_completed_diagnosis(db, brand.id, newest=True)
    # Comparing a diagnosis with itself is meaningless, so two distinct
    # records are required.
    if baseline_record and latest_record and baseline_record.id != latest_record.id:
        ab_comparison = calculator.generate_ab_comparison(
            before_score=baseline_record.overall_score or 0,
            after_score=latest_record.overall_score or 0,
            before_dimensions=_dimensions_by_name(baseline_record),
            after_dimensions=_dimensions_by_name(latest_record),
        )

    return ROIReport(
        brand_id=str(brand.id),
        brand_name=brand.name,
        subscription_cost=subscription_cost,
        current_plan=user_plan,
        total_score_delta=summary["total_score_delta"],
        value_generated=roi_data["value_generated"],
        roi_percentage=roi_data["roi_percentage"],
        break_even_delta=roi_data["break_even_delta"],
        tracking_records=[_record_to_response(r) for r in summary["records"]],
        ab_comparison=ab_comparison,
    )
|
||
|
||
|
||
@router.get("/ab-comparison/{brand_id}", response_model=ABComparisonResponse)
async def get_ab_comparison(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Compare a brand's earliest and latest completed diagnoses.

    Both "before" and "after" are chosen by ``DiagnosisRecord.completed_at``.
    With a single completed diagnosis, that one record is used on both sides.

    Raises:
        HTTPException: 404 if the brand does not exist, belongs to another user,
            or has no completed diagnosis.
    """
    brand = await _get_brand_or_404(brand_id, current_user, db)

    completed_diagnoses = select(DiagnosisRecord).where(
        DiagnosisRecord.brand_id == brand.id,
        DiagnosisRecord.status == "completed",
    )

    async def first_by(ordering):
        # At most one row: the first completed diagnosis under ``ordering``.
        rows = await db.execute(completed_diagnoses.order_by(ordering).limit(1))
        return rows.scalar_one_or_none()

    baseline_record = await first_by(DiagnosisRecord.completed_at.asc())
    latest_record = await first_by(DiagnosisRecord.completed_at.desc())

    if not (baseline_record and latest_record):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="暂无诊断数据,无法生成A/B对比",
        )

    def by_name(diagnosis):
        # Map each entry of result_json["dimensions"] by its "name" key.
        payload = diagnosis.result_json
        entries = payload.get("dimensions", []) if payload else []
        return {entry.get("name"): entry for entry in entries}

    comparison = ROICalculator().generate_ab_comparison(
        before_score=baseline_record.overall_score or 0,
        after_score=latest_record.overall_score or 0,
        before_dimensions=by_name(baseline_record),
        after_dimensions=by_name(latest_record),
    )

    return ABComparisonResponse(
        brand_id=str(brand.id),
        brand_name=brand.name,
        overall_before=comparison["overall_before"],
        overall_after=comparison["overall_after"],
        overall_delta=comparison["overall_delta"],
        dimensions=comparison["dimensions"],
    )
|