151 lines
5.2 KiB
Python
151 lines
5.2 KiB
Python
import logging
|
|
import uuid
|
|
from datetime import UTC, datetime, timedelta
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.attribution_record import AttributionRecord
|
|
from app.models.diagnosis_record import DiagnosisRecord
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AttributionEngine:
|
|
async def start_tracking(
|
|
self,
|
|
db: AsyncSession,
|
|
brand_id: uuid.UUID,
|
|
content_id: uuid.UUID | None,
|
|
user_id: str,
|
|
) -> AttributionRecord:
|
|
baseline_score = await self._get_latest_score(db, brand_id)
|
|
|
|
now = datetime.now(UTC)
|
|
record = AttributionRecord(
|
|
user_id=user_id,
|
|
brand_id=brand_id,
|
|
content_id=content_id,
|
|
baseline_score=baseline_score,
|
|
published_at=now,
|
|
window_end_at=now + timedelta(days=28),
|
|
status="tracking",
|
|
)
|
|
db.add(record)
|
|
await db.commit()
|
|
await db.refresh(record)
|
|
return record
|
|
|
|
async def check_attribution(
|
|
self,
|
|
db: AsyncSession,
|
|
record_id: uuid.UUID,
|
|
) -> AttributionRecord:
|
|
stmt = select(AttributionRecord).where(AttributionRecord.id == record_id)
|
|
result = await db.execute(stmt)
|
|
record = result.scalar_one_or_none()
|
|
if not record:
|
|
raise ValueError(f"AttributionRecord {record_id} not found")
|
|
|
|
current_score = await self._get_latest_score(db, record.brand_id)
|
|
record.current_score = current_score
|
|
record.score_delta = round(current_score - record.baseline_score, 2)
|
|
|
|
baseline_dims = await self._get_latest_dimensions(db, record.brand_id, record.published_at)
|
|
current_dims = await self._get_latest_dimensions(db, record.brand_id, None)
|
|
if baseline_dims and current_dims:
|
|
record.attributed_dimensions = self._compute_dimension_deltas(
|
|
baseline_dims, current_dims
|
|
)
|
|
|
|
now = datetime.now(UTC)
|
|
if record.window_end_at:
|
|
window_end = record.window_end_at
|
|
if window_end.tzinfo is None:
|
|
window_end = window_end.replace(tzinfo=UTC)
|
|
if now >= window_end:
|
|
record.status = "completed"
|
|
elif record.score_delta and record.score_delta > 0:
|
|
record.status = "tracking"
|
|
|
|
await db.commit()
|
|
await db.refresh(record)
|
|
return record
|
|
|
|
async def get_brand_attribution_summary(
|
|
self,
|
|
db: AsyncSession,
|
|
brand_id: uuid.UUID,
|
|
) -> dict:
|
|
stmt = (
|
|
select(AttributionRecord)
|
|
.where(AttributionRecord.brand_id == brand_id)
|
|
.order_by(AttributionRecord.created_at.desc())
|
|
)
|
|
result = await db.execute(stmt)
|
|
records = result.scalars().all()
|
|
|
|
total_delta = sum(r.score_delta or 0 for r in records)
|
|
tracking_count = sum(1 for r in records if r.status == "tracking")
|
|
completed_count = sum(1 for r in records if r.status == "completed")
|
|
positive_count = sum(1 for r in records if (r.score_delta or 0) > 0)
|
|
|
|
return {
|
|
"brand_id": str(brand_id),
|
|
"records": records,
|
|
"total_score_delta": round(total_delta, 2),
|
|
"tracking_count": tracking_count,
|
|
"completed_count": completed_count,
|
|
"positive_count": positive_count,
|
|
}
|
|
|
|
async def _get_latest_score(self, db: AsyncSession, brand_id: uuid.UUID) -> float:
|
|
stmt = (
|
|
select(DiagnosisRecord)
|
|
.where(
|
|
DiagnosisRecord.brand_id == brand_id,
|
|
DiagnosisRecord.status == "completed",
|
|
)
|
|
.order_by(DiagnosisRecord.completed_at.desc())
|
|
.limit(1)
|
|
)
|
|
result = await db.execute(stmt)
|
|
record = result.scalar_one_or_none()
|
|
if record and record.overall_score is not None:
|
|
return float(record.overall_score)
|
|
logger.warning("No completed DiagnosisRecord for brand %s, using 0 as baseline", brand_id)
|
|
return 0.0
|
|
|
|
async def _get_latest_dimensions(
|
|
self,
|
|
db: AsyncSession,
|
|
brand_id: uuid.UUID,
|
|
before: datetime | None,
|
|
) -> dict | None:
|
|
stmt = (
|
|
select(DiagnosisRecord)
|
|
.where(
|
|
DiagnosisRecord.brand_id == brand_id,
|
|
DiagnosisRecord.status == "completed",
|
|
)
|
|
.order_by(DiagnosisRecord.completed_at.desc())
|
|
.limit(1)
|
|
)
|
|
if before:
|
|
stmt = stmt.where(DiagnosisRecord.completed_at < before)
|
|
result = await db.execute(stmt)
|
|
record = result.scalar_one_or_none()
|
|
if record and record.result_json:
|
|
return record.result_json.get("dimensions")
|
|
return None
|
|
|
|
def _compute_dimension_deltas(self, before_dims: list, after_dims: list) -> dict:
|
|
before_map = {d.get("name"): d.get("score", 0) for d in before_dims}
|
|
after_map = {d.get("name"): d.get("score", 0) for d in after_dims}
|
|
deltas = {}
|
|
for name in after_map:
|
|
b = before_map.get(name, 0)
|
|
a = after_map[name]
|
|
deltas[name] = {"before": b, "after": a, "delta": round(a - b, 2)}
|
|
return deltas
|