# geo/backend/tests/test_api/test_attribution_contract.py
#
# Contract tests for the /api/v1/attribution API endpoints.
# (Source listing metadata: 374 lines, 11 KiB, Python.)

import uuid
from datetime import UTC, datetime, timedelta
import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from app.api.deps import get_current_user, get_db
from app.database import Base
from app.main import app
from app.models.attribution_record import AttributionRecord
from app.models.brand import Brand
from app.models.diagnosis_record import DiagnosisRecord
from app.models.user import User
from app.services.auth import hash_password
from tests.fixtures.auth import _to_uuid
def _make_user(
    user_id: str | None = None,
    email: str = "test@example.com",
    plan: str = "free",
) -> User:
    """Return a new, unsaved, active and email-verified ``User``.

    Args:
        user_id: Id to assign; a random UUID *string* is generated when this
            is falsy (``None`` or ``""``).
        email: Login email address.
        plan: Subscription plan name. ``"free"`` users get ``max_queries``
            of 5; every other plan gets 50.

    Returns:
        The ``User`` instance; it is not added to any session.
    """
    is_free = plan == "free"
    user = User(
        id=user_id or str(uuid.uuid4()),
        email=email,
        password=hash_password("Test@123456"),
        firstName="Test",
        lastName="User",
        isActive=True,
        emailVerified=True,
    )
    # NOTE(review): plan / max_queries are assigned as attributes after
    # construction — confirm they are mapped columns if they must persist.
    user.plan = plan
    user.max_queries = 5 if is_free else 50
    return user
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
@pytest_asyncio.fixture
async def async_engine():
    """Yield a fresh in-memory SQLite async engine with all tables created.

    ``StaticPool`` together with ``check_same_thread=False`` makes every
    checkout reuse one underlying SQLite connection, so the ``:memory:``
    database (and the schema created here) is shared by all sessions for
    the lifetime of this function-scoped fixture.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        yield engine
    finally:
        # Dispose even when schema creation fails, so the pool/connection
        # is never leaked into later tests.
        await engine.dispose()
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield an ``AsyncSession`` bound to the per-test in-memory engine.

    ``expire_on_commit=False`` keeps fixture objects' loaded attributes
    readable after the ``commit()`` calls made by fixtures and tests.
    ``autoflush=False`` means pending objects are only written on an
    explicit flush/commit.

    The legacy ``autocommit`` flag is intentionally omitted: SQLAlchemy 2.0
    keeps it only for backward compatibility and it must be ``False``, which
    is already the default.
    """
    maker = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
    )
    async with maker() as session:
        yield session
@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist a ``pro``-plan user and return it refreshed from the session."""
    session = async_session
    pro_user = _make_user(plan="pro")
    session.add(pro_user)
    await session.commit()
    await session.refresh(pro_user)
    return pro_user
@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Persist an active brand named ``TestBrand`` owned by ``test_user``.

    ``_make_user`` assigns a string id, so the owner id is coerced with
    ``_to_uuid`` before it is stored on the brand.

    Returns:
        The committed and refreshed ``Brand``.
    """
    brand_fields = {
        "id": uuid.uuid4(),
        "user_id": _to_uuid(test_user.id),
        "name": "TestBrand",
        "aliases": ["TB"],
        "website": "https://testbrand.com",
        "industry": "technology",
        "platforms": ["wenxin"],
        "frequency": "weekly",
        "status": "active",
    }
    new_brand = Brand(**brand_fields)
    async_session.add(new_brand)
    await async_session.commit()
    await async_session.refresh(new_brand)
    return new_brand
@pytest_asyncio.fixture
async def test_diagnosis(async_session, test_brand, test_user):
    """Persist a completed ``geo`` diagnosis with overall score 45.0.

    The record belongs to ``test_brand`` / ``test_user`` and is stamped as
    completed at the current UTC time.

    Returns:
        The committed and refreshed ``DiagnosisRecord``.
    """
    baseline = 45.0
    # Dimension names are runtime data and must stay verbatim.
    dimensions = [
        {"name": "内容可提取性", "score": 50},  # content extractability
        {"name": "E-E-A-T信号", "score": 40},  # E-E-A-T signals
        {"name": "引用就绪度", "score": 45},  # citation readiness
    ]
    diagnosis = DiagnosisRecord(
        brand_id=test_brand.id,
        user_id=_to_uuid(test_user.id),
        diagnosis_type="geo",
        status="completed",
        overall_score=baseline,
        result_json={
            "overall_score": baseline,
            "health_level": "pass",
            "dimensions": dimensions,
        },
        completed_at=datetime.now(UTC),
    )
    async_session.add(diagnosis)
    await async_session.commit()
    await async_session.refresh(diagnosis)
    return diagnosis
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an ``httpx.AsyncClient`` that calls ``app`` in-process.

    ``get_db`` is overridden to hand out this test's ``async_session`` and
    ``get_current_user`` to return ``test_user``. Requests therefore see the
    test's database state and are authenticated as that user.

    On teardown only the two overrides installed here are removed (or
    restored to whatever they were before). Overrides registered elsewhere
    are left untouched, and cleanup also runs if client setup fails.
    """

    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    overrides = {
        get_db: override_get_db,
        get_current_user: override_get_current_user,
    }
    # Snapshot pre-existing overrides for these keys so teardown can restore
    # them instead of clobbering the whole mapping with .clear().
    previous = {
        dep: app.dependency_overrides[dep]
        for dep in overrides
        if dep in app.dependency_overrides
    }
    app.dependency_overrides.update(overrides)
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        for dep in overrides:
            if dep in previous:
                app.dependency_overrides[dep] = previous[dep]
            else:
                app.dependency_overrides.pop(dep, None)
def _make_attribution_record(
    user_id: str,
    brand_id: uuid.UUID,
    baseline_score: float = 45.0,
    current_score: float | None = None,
    score_delta: float | None = None,
    status: str = "tracking",
    published_at: datetime | None = None,
    window_end_at: datetime | None = None,
) -> AttributionRecord:
    """Build an unsaved ``AttributionRecord`` with test-friendly defaults.

    Args:
        user_id: Owner id, passed through unchanged. NOTE(review): unlike
            the Brand/Diagnosis fixtures this is not coerced with
            ``_to_uuid``; confirm the column type of
            ``AttributionRecord.user_id``.
        brand_id: Id of the tracked brand.
        baseline_score: Score at the start of tracking.
        current_score: Latest observed score, if any.
        score_delta: Current minus baseline, if already computed.
        status: Record status; ``"tracking"`` by default.
        published_at: Publication time; defaults to the current UTC time.
        window_end_at: End of the attribution window. Defaults to 28 days
            after ``published_at``, so a back-dated publication yields a
            window anchored to it rather than to "now".

    Returns:
        The ``AttributionRecord`` instance; it is not added to any session.
    """
    if published_at is None:
        published_at = datetime.now(UTC)
    if window_end_at is None:
        window_end_at = published_at + timedelta(days=28)
    return AttributionRecord(
        user_id=user_id,
        brand_id=brand_id,
        baseline_score=baseline_score,
        current_score=current_score,
        score_delta=score_delta,
        status=status,
        published_at=published_at,
        window_end_at=window_end_at,
    )
class TestStartTrackingContract:
    """Contract for ``POST /api/v1/attribution/start``."""

    _URL = "/api/v1/attribution/start"

    @pytest.mark.asyncio
    async def test_start_creates_attribution_record(
        self, async_client, test_brand, test_diagnosis
    ):
        brand_id = str(test_brand.id)
        resp = await async_client.post(self._URL, json={"brand_id": brand_id})
        assert resp.status_code == 200

        body = resp.json()
        assert "id" in body
        assert body["brand_id"] == brand_id
        # Expected to equal the 45.0 overall score of the test_diagnosis fixture.
        assert body["baseline_score"] == 45.0
        assert body["status"] == "tracking"
        assert body["content_id"] is None

    @pytest.mark.asyncio
    async def test_start_with_content_id(
        self, async_client, test_brand, test_diagnosis
    ):
        content_id = str(uuid.uuid4())
        payload = {"brand_id": str(test_brand.id), "content_id": content_id}
        resp = await async_client.post(self._URL, json=payload)
        assert resp.status_code == 200
        assert resp.json()["content_id"] == content_id

    @pytest.mark.asyncio
    async def test_start_with_invalid_brand_id(self, async_client):
        unknown_brand_id = str(uuid.uuid4())
        resp = await async_client.post(
            self._URL, json={"brand_id": unknown_brand_id}
        )
        assert resp.status_code == 404

    @pytest.mark.asyncio
    async def test_start_without_diagnosis_uses_zero_baseline(
        self, async_client, test_brand
    ):
        resp = await async_client.post(
            self._URL, json={"brand_id": str(test_brand.id)}
        )
        assert resp.status_code == 200
        assert resp.json()["baseline_score"] == 0.0
class TestGetBrandAttributionContract:
    """Contract for ``GET /api/v1/attribution/brand/{brand_id}``."""

    @pytest.mark.asyncio
    async def test_get_brand_attribution_summary(
        self, async_client, test_brand, test_user, test_diagnosis, async_session
    ):
        async_session.add(
            _make_attribution_record(
                user_id=test_user.id,
                brand_id=test_brand.id,
                current_score=55.0,
                score_delta=10.0,
            )
        )
        await async_session.commit()

        resp = await async_client.get(f"/api/v1/attribution/brand/{test_brand.id}")
        assert resp.status_code == 200

        summary = resp.json()
        for key in ("records", "total_score_delta", "tracking_count"):
            assert key in summary
        assert len(summary["records"]) >= 1
class TestCheckAttributionContract:
    """Contract for ``POST /api/v1/attribution/{record_id}/check``."""

    @staticmethod
    def _check_url(record_id) -> str:
        """Return the check endpoint path for *record_id*."""
        return f"/api/v1/attribution/{record_id}/check"

    @pytest.mark.asyncio
    async def test_check_updates_attribution_record(
        self, async_client, test_brand, test_user, test_diagnosis, async_session
    ):
        record = _make_attribution_record(
            user_id=test_user.id, brand_id=test_brand.id
        )
        async_session.add(record)
        await async_session.commit()
        # Reload persisted state (including the generated id) before
        # building the URL.
        await async_session.refresh(record)

        resp = await async_client.post(self._check_url(record.id))
        assert resp.status_code == 200

        checked = resp.json()
        assert checked["current_score"] is not None
        assert checked["score_delta"] is not None

    @pytest.mark.asyncio
    async def test_check_nonexistent_record(self, async_client):
        resp = await async_client.post(self._check_url(uuid.uuid4()))
        assert resp.status_code == 404
class TestGetROIReportContract:
    """Contract for ``GET /api/v1/attribution/roi/{brand_id}``."""

    # Fields every ROI report response is expected to expose.
    _REPORT_KEYS = (
        "roi_percentage",
        "value_generated",
        "subscription_cost",
        "break_even_delta",
        "brand_name",
        "current_plan",
        "tracking_records",
    )

    @pytest.mark.asyncio
    async def test_get_roi_report(
        self, async_client, test_brand, test_user, test_diagnosis, async_session
    ):
        tracked = _make_attribution_record(
            user_id=test_user.id,
            brand_id=test_brand.id,
            current_score=55.0,
            score_delta=10.0,
        )
        async_session.add(tracked)
        await async_session.commit()

        resp = await async_client.get(f"/api/v1/attribution/roi/{test_brand.id}")
        assert resp.status_code == 200

        report = resp.json()
        missing = [key for key in self._REPORT_KEYS if key not in report]
        assert not missing
        assert report["brand_name"] == "TestBrand"

    @pytest.mark.asyncio
    async def test_get_roi_invalid_brand(self, async_client):
        resp = await async_client.get(f"/api/v1/attribution/roi/{uuid.uuid4()}")
        assert resp.status_code == 404
class TestGetABComparisonContract:
    """Contract for ``GET /api/v1/attribution/ab-comparison/{brand_id}``."""

    @staticmethod
    def _url(brand_id) -> str:
        """Return the A/B comparison endpoint path for *brand_id*."""
        return f"/api/v1/attribution/ab-comparison/{brand_id}"

    @pytest.mark.asyncio
    async def test_get_ab_comparison(
        self, async_client, test_brand, test_user, test_diagnosis, async_session
    ):
        # A second, later (+1h) and higher-scoring diagnosis; presumably the
        # endpoint compares it ("after") with test_diagnosis ("before").
        after_score = 65.0
        async_session.add(
            DiagnosisRecord(
                brand_id=test_brand.id,
                user_id=_to_uuid(test_user.id),
                diagnosis_type="geo",
                status="completed",
                overall_score=after_score,
                result_json={
                    "overall_score": after_score,
                    "health_level": "good",
                    "dimensions": [
                        {"name": "内容可提取性", "score": 70},  # content extractability
                        {"name": "E-E-A-T信号", "score": 60},  # E-E-A-T signals
                        {"name": "引用就绪度", "score": 65},  # citation readiness
                    ],
                },
                completed_at=datetime.now(UTC) + timedelta(hours=1),
            )
        )
        await async_session.commit()

        resp = await async_client.get(self._url(test_brand.id))
        assert resp.status_code == 200

        comparison = resp.json()
        for key in ("overall_before", "overall_after", "overall_delta", "dimensions"):
            assert key in comparison
        assert comparison["brand_name"] == "TestBrand"
        assert len(comparison["dimensions"]) > 0

    @pytest.mark.asyncio
    async def test_get_ab_comparison_no_data(self, async_client, test_brand):
        resp = await async_client.get(self._url(test_brand.id))
        assert resp.status_code == 404
class TestAttributionWindowExpiration:
    """Checking a record whose attribution window has already closed."""

    @pytest.mark.asyncio
    async def test_expired_window_marks_completed(
        self, async_client, test_brand, test_user, test_diagnosis, async_session
    ):
        now = datetime.now(UTC)
        # Published 30 days ago; the window closed 2 days ago.
        expired = _make_attribution_record(
            user_id=test_user.id,
            brand_id=test_brand.id,
            published_at=now - timedelta(days=30),
            window_end_at=now - timedelta(days=2),
        )
        async_session.add(expired)
        await async_session.commit()
        await async_session.refresh(expired)

        resp = await async_client.post(f"/api/v1/attribution/{expired.id}/check")
        assert resp.status_code == 200
        assert resp.json()["status"] == "completed"