# geo/backend/tests/test_api/test_diagnosis_contract.py
#
# (file-viewer metadata: 400 lines, 12 KiB, Python)

import uuid
from unittest.mock import AsyncMock, patch
import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from app.api.deps import get_current_user, get_db
from app.database import Base
from app.main import app
from app.models.brand import Brand
from app.models.diagnosis_record import DiagnosisRecord
from app.models.user import User
from app.services.auth import hash_password
from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput
def _make_user(
    user_id: str | None = None,
    email: str = "test@example.com",
    plan: str = "free",
) -> User:
    """Build an unsaved, active, email-verified ``User`` for tests.

    Args:
        user_id: Explicit primary key. When omitted (or falsy) a random UUID4
            string is generated.
        email: Login e-mail address.
        plan: Plan name stored on ``user.plan``.

    Returns:
        A transient ``User`` whose password is the hash of ``"Test@123456"``.
        The caller must add it to a session. ``max_queries`` is 5 for the
        ``"free"`` plan and 50 for any other plan.
    """
    user = User(
        id=user_id or str(uuid.uuid4()),
        email=email,
        password=hash_password("Test@123456"),
        firstName="Test",
        lastName="User",
        isActive=True,
        emailVerified=True,
    )
    is_free_plan = plan == "free"
    user.plan = plan
    user.max_queries = 5 if is_free_plan else 50
    return user
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
@pytest_asyncio.fixture
async def async_engine():
    """Yield an async engine over a fresh in-memory SQLite DB with all tables created.

    ``StaticPool`` reuses a single connection, so the in-memory database (and
    the schema created here) stays visible to every session opened later.
    ``check_same_thread=False`` allows that connection to be used from threads
    other than the creating one.

    The engine is disposed on teardown, and also when schema creation itself
    fails (previously a failing ``create_all`` leaked the engine).
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        yield engine
    finally:
        await engine.dispose()
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield one ``AsyncSession`` bound to the test engine for a single test.

    ``expire_on_commit=False`` keeps ORM attributes loaded after ``commit()``,
    so committed fixture objects can be read by tests without reloading.
    ``autoflush=False`` means pending changes are flushed only on an explicit
    flush/commit.
    """
    session_factory = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        autocommit=False,
        autoflush=False,
        expire_on_commit=False,
    )
    async with session_factory() as db_session:
        yield db_session
@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist and return a free-plan user with the default test e-mail."""
    free_user = _make_user(plan="free")
    async_session.add(free_user)
    await async_session.commit()
    # Reload server-side defaults / generated columns onto the instance.
    await async_session.refresh(free_user)
    return free_user
@pytest_asyncio.fixture
async def paid_user(async_session):
    """Persist and return a ``pro``-plan user (``paid@example.com``)."""
    pro_user = _make_user(email="paid@example.com", plan="pro")
    async_session.add(pro_user)
    await async_session.commit()
    # Reload server-side defaults / generated columns onto the instance.
    await async_session.refresh(pro_user)
    return pro_user
@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Persist and return an active, weekly-frequency brand owned by ``test_user``."""
    brand_id = uuid.uuid4()
    owner_id = _to_uuid(test_user.id)
    brand = Brand(
        id=brand_id,
        user_id=owner_id,
        name="TestBrand",
        aliases=["TB", "Test Brand"],
        website="https://testbrand.com",
        industry="technology",
        platforms=["wenxin", "kimi"],
        frequency="weekly",
        status="active",
    )
    async_session.add(brand)
    await async_session.commit()
    await async_session.refresh(brand)
    return brand
@pytest_asyncio.fixture
async def paid_brand(async_session, paid_user):
    """Persist and return an active, weekly-frequency brand owned by ``paid_user``."""
    brand_id = uuid.uuid4()
    owner_id = _to_uuid(paid_user.id)
    brand = Brand(
        id=brand_id,
        user_id=owner_id,
        name="PaidBrand",
        aliases=["PB"],
        website="https://paidbrand.com",
        industry="technology",
        platforms=["wenxin", "kimi"],
        frequency="weekly",
        status="active",
    )
    async_session.add(brand)
    await async_session.commit()
    await async_session.refresh(brand)
    return brand
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an HTTP client for ``app`` authenticated as the free-plan ``test_user``.

    ``get_db`` and ``get_current_user`` are overridden so requests reuse this
    test's ``async_session`` and bypass real authentication.

    On teardown the previous contents of ``app.dependency_overrides`` are
    restored rather than cleared, so overrides installed elsewhere survive.
    Restoration happens even if closing the client raises.
    """
    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    saved_overrides = dict(app.dependency_overrides)
    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        # Mutate in place: other references to the mapping must see the restore.
        app.dependency_overrides.clear()
        app.dependency_overrides.update(saved_overrides)
@pytest_asyncio.fixture
async def paid_client(async_session, paid_user):
    """Yield an HTTP client for ``app`` authenticated as the ``pro``-plan ``paid_user``.

    ``get_db`` and ``get_current_user`` are overridden so requests reuse this
    test's ``async_session`` and bypass real authentication.

    On teardown the previous contents of ``app.dependency_overrides`` are
    restored rather than cleared, so overrides installed elsewhere survive.
    Restoration happens even if closing the client raises.
    """
    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return paid_user

    saved_overrides = dict(app.dependency_overrides)
    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        # Mutate in place: other references to the mapping must see the restore.
        app.dependency_overrides.clear()
        app.dependency_overrides.update(saved_overrides)
class TestGEODiagnosisTriggerContract:
    """Contract tests for ``POST /api/v1/diagnosis/geo/{brand_id}``."""

    @pytest.mark.asyncio
    async def test_trigger_returns_202(self, async_client, test_brand):
        """Triggering a diagnosis for an owned brand is accepted with a task descriptor."""
        url = f"/api/v1/diagnosis/geo/{test_brand.id}"
        # Stub app.api.diagnosis._run_geo_diagnosis (presumably the diagnosis
        # worker) so the request does not run a real diagnosis.
        with patch("app.api.diagnosis._run_geo_diagnosis", new_callable=AsyncMock):
            response = await async_client.post(url)
        assert response.status_code == 202
        body = response.json()
        for key in ("task_id", "brand_id", "status"):
            assert key in body
        assert body["status"] in ("pending", "completed")
        assert body["brand_id"] == str(test_brand.id)

    @pytest.mark.asyncio
    async def test_trigger_brand_not_found(self, async_client):
        """Triggering for a random, non-existent brand id yields 404."""
        missing_brand_id = uuid.uuid4()
        response = await async_client.post(f"/api/v1/diagnosis/geo/{missing_brand_id}")
        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_trigger_force_refresh(self, async_client, test_brand):
        """A ``force_refresh`` request body is accepted with 202."""
        url = f"/api/v1/diagnosis/geo/{test_brand.id}"
        with patch("app.api.diagnosis._run_geo_diagnosis", new_callable=AsyncMock):
            response = await async_client.post(url, json={"force_refresh": True})
        assert response.status_code == 202
class TestGEODiagnosisResultContract:
    """Contract tests for ``GET /api/v1/diagnosis/geo/{brand_id}/result?task_id=...``."""

    @pytest.mark.asyncio
    async def test_result_pending(self, async_client, test_brand, async_session):
        """A task whose record is still ``pending`` reports that status and no result."""
        pending = DiagnosisRecord(
            brand_id=test_brand.id,
            user_id=_to_uuid(test_brand.user_id),
            diagnosis_type="geo",
            status="pending",
        )
        async_session.add(pending)
        await async_session.commit()
        await async_session.refresh(pending)

        response = await async_client.get(
            f"/api/v1/diagnosis/geo/{test_brand.id}/result",
            params={"task_id": str(pending.id)},
        )
        assert response.status_code == 200
        payload = response.json()
        assert payload["status"] == "pending"
        assert payload["result"] is None

    @pytest.mark.asyncio
    async def test_result_completed_nonzero_score(
        self, async_client, test_brand, async_session
    ):
        """A completed record built from a real diagnosis is served with a positive score."""
        from app.services.diagnosis.geo_diagnosis import GEODiagnosisService

        sample = GEODiagnosisInput(
            has_direct_answer=True,
            has_brand_definition=True,
            has_author_bio=True,
            author_credentials_complete=0.8,
            answer_ownership_rate=0.3,
        )
        diagnosis = GEODiagnosisService().diagnose(sample)
        completed = DiagnosisRecord(
            brand_id=test_brand.id,
            user_id=_to_uuid(test_brand.user_id),
            diagnosis_type="geo",
            status="completed",
            overall_score=diagnosis.overall_score,
            result_json=diagnosis.to_dict(),
        )
        async_session.add(completed)
        await async_session.commit()
        await async_session.refresh(completed)

        response = await async_client.get(
            f"/api/v1/diagnosis/geo/{test_brand.id}/result",
            params={"task_id": str(completed.id)},
        )
        assert response.status_code == 200
        payload = response.json()
        assert payload["status"] == "completed"
        assert payload["result"] is not None
        assert payload["result"]["overall_score"] > 0

    @pytest.mark.asyncio
    async def test_result_not_found(self, async_client, test_brand):
        """An unknown ``task_id`` for an existing brand yields 404."""
        unknown_task_id = str(uuid.uuid4())
        response = await async_client.get(
            f"/api/v1/diagnosis/geo/{test_brand.id}/result",
            params={"task_id": unknown_task_id},
        )
        assert response.status_code == 404
class TestGEODiagnosisFreeVsPaidContract:
    """Contract tests for plan-based gating of the GEO diagnosis result payload.

    Both tests store a completed diagnosis computed from the same fixed input.
    The free-plan user must get a partial report, the paid user the full one.
    """

    @staticmethod
    def _completed_record(brand_id, user_id) -> DiagnosisRecord:
        """Run the real GEO diagnosis on a fixed sample and wrap it in a completed record."""
        from app.services.diagnosis.geo_diagnosis import GEODiagnosisService

        input_data = GEODiagnosisInput(
            has_direct_answer=True,
            has_brand_definition=True,
            has_author_bio=True,
            author_credentials_complete=0.8,
            has_organization=True,
            content_depth_score=0.7,
            answer_ownership_rate=0.3,
        )
        result = GEODiagnosisService().diagnose(input_data)
        return DiagnosisRecord(
            brand_id=brand_id,
            user_id=user_id,
            diagnosis_type="geo",
            status="completed",
            overall_score=result.overall_score,
            result_json=result.to_dict(),
        )

    @pytest.mark.asyncio
    async def test_free_user_gets_3_dimensions(
        self, async_client, test_brand, async_session
    ):
        """A free-plan user receives a partial report with exactly three named dimensions."""
        record = self._completed_record(test_brand.id, _to_uuid(test_brand.user_id))
        async_session.add(record)
        await async_session.commit()
        await async_session.refresh(record)

        response = await async_client.get(
            f"/api/v1/diagnosis/geo/{test_brand.id}/result",
            params={"task_id": str(record.id)},
        )
        # Check the status first so an error response fails with a clear message
        # instead of a KeyError below.
        assert response.status_code == 200
        result = response.json()["result"]
        assert result["is_full_report"] is False
        dimensions = result["dimensions"]
        # Check the list length as well: a set of names alone would hide
        # duplicated dimensions.
        assert len(dimensions) == 3
        # Content extractability, E-E-A-T signals, citation readiness.
        assert {d["name"] for d in dimensions} == {
            "内容可提取性",
            "E-E-A-T信号",
            "引用就绪度",
        }

    @pytest.mark.asyncio
    async def test_paid_user_gets_6_dimensions(
        self, paid_client, paid_brand, async_session, paid_user
    ):
        """A paid-plan user receives the full report with six dimensions."""
        record = self._completed_record(paid_brand.id, _to_uuid(paid_user.id))
        async_session.add(record)
        await async_session.commit()
        await async_session.refresh(record)

        response = await paid_client.get(
            f"/api/v1/diagnosis/geo/{paid_brand.id}/result",
            params={"task_id": str(record.id)},
        )
        assert response.status_code == 200
        result = response.json()["result"]
        assert result["is_full_report"] is True
        assert len(result["dimensions"]) == 6
class TestGEODiagnosisHistoryContract:
    """Contract tests for ``GET /api/v1/diagnosis/geo/{brand_id}/history``."""

    @pytest.mark.asyncio
    async def test_history_returns_list(
        self, async_client, test_brand, async_session
    ):
        """History for a brand with one completed diagnosis is a non-empty list for that brand."""
        record = DiagnosisRecord(
            brand_id=test_brand.id,
            user_id=_to_uuid(test_brand.user_id),
            diagnosis_type="geo",
            status="completed",
            overall_score=45.0,
            result_json={
                "overall_score": 45.0,
                "health_level": "pass",
            },
        )
        async_session.add(record)
        await async_session.commit()

        response = await async_client.get(
            f"/api/v1/diagnosis/geo/{test_brand.id}/history"
        )
        assert response.status_code == 200
        data = response.json()
        assert "brand_id" in data
        # Same brand_id representation the trigger endpoint is asserted to return.
        assert data["brand_id"] == str(test_brand.id)
        assert "history" in data
        assert isinstance(data["history"], list)
        # The record inserted above must show up; otherwise an endpoint that
        # always returns [] would satisfy this contract.
        assert data["history"]
class TestGEODiagnosisDataCollectionContract:
    """Contract tests for diagnosis endpoints that are expected to use data collection."""

    @pytest.mark.asyncio
    async def test_diagnosis_with_data_collection_produces_nonzero(
        self, async_client, test_brand
    ):
        """A forced-refresh GEO trigger is accepted with 202.

        NOTE(review): despite its name, this test stubs
        ``app.api.diagnosis._run_geo_diagnosis`` and checks only the status
        code, so no score is produced or inspected here.
        """
        endpoint = f"/api/v1/diagnosis/geo/{test_brand.id}"
        with patch("app.api.diagnosis._run_geo_diagnosis", new_callable=AsyncMock):
            response = await async_client.post(endpoint, json={"force_refresh": True})
        assert response.status_code == 202

    @pytest.mark.asyncio
    async def test_combined_diagnosis_uses_data_collector(
        self, async_client, test_brand
    ):
        """The combined endpoint returns GEO, SEO and combined scores.

        NOTE(review): nothing is patched here, so whatever collaborators the
        endpoint uses run for real. Confirm this cannot reach the network.
        The test also does not verify that a data collector is invoked.
        """
        response = await async_client.get(
            f"/api/v1/diagnosis/combined/{test_brand.id}"
        )
        assert response.status_code == 200
        payload = response.json()
        for field in ("geo_score", "seo_score", "combined_score"):
            assert field in payload