geo/backend/app/api/attribution.py

301 lines
9.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import uuid
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.database import get_db
from app.models.attribution_record import AttributionRecord
from app.models.brand import Brand
from app.models.diagnosis_record import DiagnosisRecord
from app.models.user import User
from app.services.attribution.attribution_engine import AttributionEngine
from app.services.attribution.roi_calculator import ROICalculator
# FastAPI router collecting the attribution endpoints defined in this module.
router = APIRouter()
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
# Subscription cost per plan name. The ROI report looks the caller's plan up
# here; plans missing from this table are priced at 0.0 by the caller.
# NOTE(review): currency and billing period are not visible here — confirm.
PLAN_COSTS = dict(
    free=0.0,
    starter=99.0,
    pro=299.0,
    enterprise=999.0,
)
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
class StartTrackingRequest(BaseModel):
    """Request body for ``POST /start``.

    IDs arrive as plain strings; the endpoint (not this model) parses them
    into UUIDs.
    """
    brand_id: str  # UUID string; the endpoint rejects brands the caller does not own
    content_id: str | None = None  # optional UUID string of the content to track
class AttributionResponse(BaseModel):
    """API representation of an ``AttributionRecord``.

    UUID columns are exposed as strings; ``_record_to_response`` builds
    instances from ORM records.
    """
    id: str
    brand_id: str
    content_id: str | None  # None when the record has no content_id
    baseline_score: float
    current_score: float | None
    score_delta: float | None
    status: str
    roi_percentage: float | None
    created_at: datetime
    # Pydantic v2: allow model_validate() straight from ORM attribute access.
    model_config = {"from_attributes": True}
class ROIReport(BaseModel):
    """Response body for ``GET /roi/{brand_id}``."""
    brand_id: str
    brand_name: str
    subscription_cost: float  # PLAN_COSTS price of current_plan (0.0 if unknown)
    current_plan: str
    total_score_delta: float  # "total_score_delta" of the attribution summary
    value_generated: float  # from ROICalculator.calculate_roi
    roi_percentage: float  # from ROICalculator.calculate_roi
    break_even_delta: float  # from ROICalculator.calculate_roi
    tracking_records: list[AttributionResponse]
    # Required (no default); None when no A/B comparison could be built.
    ab_comparison: dict | None
class ABComparisonResponse(BaseModel):
    """Response body for ``GET /ab-comparison/{brand_id}``.

    Compares a brand's earliest and latest completed diagnosis. The numbers
    come from ``ROICalculator.generate_ab_comparison``.
    """
    brand_id: str
    brand_name: str
    overall_before: float
    overall_after: float
    overall_delta: float
    # Per-dimension entries; their schema is defined by ROICalculator — TODO confirm.
    dimensions: list[dict]
def _record_to_response(record: AttributionRecord) -> AttributionResponse:
    """Convert an ORM ``AttributionRecord`` into an ``AttributionResponse``.

    ``id`` and ``brand_id`` are stringified. ``content_id`` is stringified
    when truthy and becomes ``None`` otherwise. All other fields are copied
    through as-is.
    """
    copied_fields = (
        "baseline_score",
        "current_score",
        "score_delta",
        "status",
        "roi_percentage",
        "created_at",
    )
    payload = {
        "id": str(record.id),
        "brand_id": str(record.brand_id),
        "content_id": str(record.content_id) if record.content_id else None,
    }
    payload.update({field: getattr(record, field) for field in copied_fields})
    return AttributionResponse(**payload)
async def _get_brand_or_404(
    brand_id: uuid.UUID,
    current_user: User,
    db: AsyncSession,
) -> Brand:
    """Load brand ``brand_id`` if it belongs to ``current_user``.

    Raises:
        HTTPException: 404 when no brand with that id is owned by the user.
            A brand owned by someone else looks the same as a missing one.
    """
    owner_id = _to_uuid(current_user.id)
    query = select(Brand).where(Brand.id == brand_id).where(Brand.user_id == owner_id)
    brand = (await db.execute(query)).scalar_one_or_none()
    if brand:
        return brand
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="品牌不存在",
    )
@router.post("/start", response_model=AttributionResponse)
async def start_tracking(
    body: StartTrackingRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Start attribution tracking for one of the caller's brands.

    ``body.brand_id`` and ``body.content_id`` are UUID strings. Both are
    parsed up front, so a malformed ID is reported as a client error. Before
    this, the ``ValueError`` from ``uuid.UUID`` escaped as a 500.

    Raises:
        HTTPException: 400 for a malformed ID; 404 if the brand does not
            belong to the caller.
    """
    try:
        brand_id = _to_uuid(body.brand_id)
        # An empty/absent content_id means "not tied to specific content".
        content_id = _to_uuid(body.content_id) if body.content_id else None
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="ID格式无效",
        ) from exc
    brand = await _get_brand_or_404(brand_id, current_user, db)
    engine = AttributionEngine()
    record = await engine.start_tracking(db, brand.id, content_id, current_user.id)
    return _record_to_response(record)
@router.get("/brand/{brand_id}")
async def get_brand_attribution(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the attribution summary for one of the caller's brands.

    This is the dict from ``AttributionEngine.get_brand_attribution_summary``.
    Its ``"records"`` entry is replaced in place with serialized responses.
    """
    owned_brand = await _get_brand_or_404(brand_id, current_user, db)
    summary = await AttributionEngine().get_brand_attribution_summary(db, owned_brand.id)
    serialized_records = list(map(_record_to_response, summary["records"]))
    summary["records"] = serialized_records
    return summary
@router.get("/{record_id}", response_model=AttributionResponse)
async def get_attribution_record(
    record_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return one attribution record owned by the caller.

    The record is joined to its brand and filtered on the brand's owner.
    Previously it was looked up by id alone, so any authenticated user could
    read another user's record by guessing its id. A record owned by someone
    else now gets the same 404 as a missing one, so ids are not leaked.

    Raises:
        HTTPException: 404 if the record does not exist or is not the caller's.
    """
    stmt = (
        select(AttributionRecord)
        .join(Brand, Brand.id == AttributionRecord.brand_id)
        .where(
            AttributionRecord.id == record_id,
            Brand.user_id == _to_uuid(current_user.id),
        )
    )
    result = await db.execute(stmt)
    record = result.scalar_one_or_none()
    if not record:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="归因记录不存在",
        )
    return _record_to_response(record)
@router.post("/{record_id}/check", response_model=AttributionResponse)
async def check_attribution(
    record_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Re-run the attribution check for one of the caller's records.

    ``AttributionEngine.check_attribution`` receives only the record id, so
    ownership is verified here first. Previously any authenticated user could
    trigger a check on (and so modify) another user's record.

    Raises:
        HTTPException: 404 if the record does not exist, is not the caller's,
            or the engine reports it missing via ``ValueError``.
    """
    ownership_stmt = (
        select(AttributionRecord.id)
        .join(Brand, Brand.id == AttributionRecord.brand_id)
        .where(
            AttributionRecord.id == record_id,
            Brand.user_id == _to_uuid(current_user.id),
        )
    )
    if (await db.execute(ownership_stmt)).scalar_one_or_none() is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="归因记录不存在",
        )
    engine = AttributionEngine()
    try:
        record = await engine.check_attribution(db, record_id)
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="归因记录不存在",
        ) from exc
    return _record_to_response(record)
@router.get("/roi/{brand_id}", response_model=ROIReport)
async def get_roi_report(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Build the ROI report for one of the caller's brands.

    The report combines three parts:

    * the brand's attribution summary from ``AttributionEngine``;
    * ROI figures from ``ROICalculator.calculate_roi``. These are priced
      with the caller's plan from ``PLAN_COSTS``: a missing or empty plan
      counts as ``"free"``, and an unknown plan costs ``0.0``;
    * an A/B comparison between the earliest and the latest completed
      diagnosis. It is ``None`` unless two distinct such diagnoses exist.
    """
    brand = await _get_brand_or_404(brand_id, current_user, db)
    summary = await AttributionEngine().get_brand_attribution_summary(db, brand.id)
    total_delta = summary["total_score_delta"]
    tracked = summary["records"]

    plan = getattr(current_user, "plan", "free") or "free"
    cost = PLAN_COSTS.get(plan, 0.0)
    calculator = ROICalculator()
    roi = calculator.calculate_roi(
        subscription_cost=cost,
        score_delta=total_delta,
        attribution_records=tracked,
    )

    async def completed_diagnosis(newest: bool):
        # Earliest (newest=False) or latest (newest=True) completed diagnosis.
        # NOTE(review): NULL completed_at ordering is database-dependent
        # (e.g. PostgreSQL sorts NULLs first under DESC) — confirm the column
        # is always set for completed records.
        finished_at = DiagnosisRecord.completed_at
        query = (
            select(DiagnosisRecord)
            .where(
                DiagnosisRecord.brand_id == brand.id,
                DiagnosisRecord.status == "completed",
            )
            .order_by(finished_at.desc() if newest else finished_at.asc())
            .limit(1)
        )
        return (await db.execute(query)).scalar_one_or_none()

    def dimensions_by_name(diagnosis):
        # Index result_json["dimensions"] entries by their "name" key.
        payload = diagnosis.result_json
        entries = payload.get("dimensions", []) if payload else []
        return {entry.get("name"): entry for entry in entries}

    first = await completed_diagnosis(newest=False)
    last = await completed_diagnosis(newest=True)

    ab_comparison = None
    # A single diagnosis is both first and last, so there is nothing to compare.
    if first and last and first.id != last.id:
        before_map = dimensions_by_name(first)
        after_map = dimensions_by_name(last)
        ab_comparison = calculator.generate_ab_comparison(
            before_score=first.overall_score or 0,
            after_score=last.overall_score or 0,
            before_dimensions=before_map,
            after_dimensions=after_map,
        )

    return ROIReport(
        brand_id=str(brand.id),
        brand_name=brand.name,
        subscription_cost=cost,
        current_plan=plan,
        total_score_delta=total_delta,
        value_generated=roi["value_generated"],
        roi_percentage=roi["roi_percentage"],
        break_even_delta=roi["break_even_delta"],
        tracking_records=[_record_to_response(item) for item in tracked],
        ab_comparison=ab_comparison,
    )
@router.get("/ab-comparison/{brand_id}", response_model=ABComparisonResponse)
async def get_ab_comparison(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Compare the earliest and latest completed diagnosis of a caller's brand.

    With exactly one completed diagnosis, that diagnosis is compared with
    itself.

    Raises:
        HTTPException: 404 if the brand is not the caller's, or if it has no
            completed diagnosis.
    """
    brand = await _get_brand_or_404(brand_id, current_user, db)

    async def completed_diagnosis(newest: bool):
        # Earliest (newest=False) or latest (newest=True) completed diagnosis.
        # NOTE(review): NULL completed_at ordering is database-dependent — confirm.
        finished_at = DiagnosisRecord.completed_at
        query = (
            select(DiagnosisRecord)
            .where(
                DiagnosisRecord.brand_id == brand.id,
                DiagnosisRecord.status == "completed",
            )
            .order_by(finished_at.desc() if newest else finished_at.asc())
            .limit(1)
        )
        return (await db.execute(query)).scalar_one_or_none()

    def dimensions_by_name(diagnosis):
        # Index result_json["dimensions"] entries by their "name" key.
        payload = diagnosis.result_json
        entries = payload.get("dimensions", []) if payload else []
        return {entry.get("name"): entry for entry in entries}

    first = await completed_diagnosis(newest=False)
    last = await completed_diagnosis(newest=True)
    if not (first and last):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="暂无诊断数据无法生成A/B对比",
        )

    calculator = ROICalculator()
    before_map = dimensions_by_name(first)
    after_map = dimensions_by_name(last)
    result = calculator.generate_ab_comparison(
        before_score=first.overall_score or 0,
        after_score=last.overall_score or 0,
        before_dimensions=before_map,
        after_dimensions=after_map,
    )
    return ABComparisonResponse(
        brand_id=str(brand.id),
        brand_name=brand.name,
        overall_before=result["overall_before"],
        overall_after=result["overall_after"],
        overall_delta=result["overall_delta"],
        dimensions=result["dimensions"],
    )