382 lines
11 KiB
Python
382 lines
11 KiB
Python
"""Tests for scoring API endpoints."""
|
|
import uuid
|
|
from datetime import datetime
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from httpx import AsyncClient, ASGITransport
|
|
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession
|
|
|
|
from app.database import Base
|
|
from app.main import app
|
|
from app.models.user import User
|
|
from app.models.brand import Brand
|
|
from app.models.competitor import Competitor
|
|
from app.models.query import Query
|
|
from app.models.citation_record import CitationRecord
|
|
from app.api.deps import get_current_user, get_db
|
|
from app.services.auth import create_access_token
|
|
from tests.fixtures.auth import _to_uuid
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_engine():
    """Yield an in-memory SQLite async engine with the schema created.

    The engine targets ``sqlite+aiosqlite:///:memory:`` with a
    ``StaticPool``; ``Base.metadata.create_all`` is run against it before
    it is handed out, and it is disposed after the consumer is done.
    """
    from sqlalchemy.ext.asyncio import create_async_engine
    from sqlalchemy.pool import StaticPool

    database_url = "sqlite+aiosqlite:///:memory:"
    sqlite_connect_args = {"check_same_thread": False}
    test_engine = create_async_engine(
        database_url,
        connect_args=sqlite_connect_args,
        poolclass=StaticPool,
    )

    # Create every table registered on Base before handing the engine out.
    async with test_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield test_engine

    await test_engine.dispose()
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield an ``AsyncSession`` bound to the ``async_engine`` fixture."""
    session_options = {
        "class_": AsyncSession,
        "expire_on_commit": False,
        "autoflush": False,
        "autocommit": False,
    }
    session_factory = async_sessionmaker(async_engine, **session_options)

    # The context manager closes the session once the consumer is finished.
    async with session_factory() as db_session:
        yield db_session
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist a ``User`` row through ``async_session`` and return it."""
    user_fields = {
        "id": str(uuid.uuid4()),
        "email": "test@example.com",
        "password": "hashed_password",
        "firstName": "Test User",
        "plan": "free",
        "max_queries": 5,
        "isActive": True,
        "emailVerified": True,
    }
    account = User(**user_fields)

    async_session.add(account)
    await async_session.commit()
    # Reload the row so database-populated columns are available to tests.
    await async_session.refresh(account)
    return account
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Persist a ``Brand`` linked to ``test_user`` and return it."""
    brand_fields = {
        "id": uuid.uuid4(),
        # The user id is stored as a string; convert it for the brand's user_id.
        "user_id": _to_uuid(test_user.id),
        "name": "TestBrand",
        "aliases": ["TestBrand", "TB"],
        "website": "https://testbrand.com",
        "industry": "technology",
        "platforms": ["wenxin", "kimi"],
        "frequency": "weekly",
        "status": "active",
    }
    owned_brand = Brand(**brand_fields)

    async_session.add(owned_brand)
    await async_session.commit()
    await async_session.refresh(owned_brand)
    return owned_brand
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_brand_with_data(async_session: AsyncSession, test_user):
    """Persist a brand together with a competitor, a query and citations.

    Three ``CitationRecord`` rows are attached to the query: two cited
    ones (platforms ``wenxin`` and ``kimi``) and one uncited ``wenxin``
    record. Only the brand is returned.
    """
    tracked_brand = Brand(
        id=uuid.uuid4(),
        user_id=_to_uuid(test_user.id),
        name="TestBrand",
        aliases=["TestAlias"],
        website="https://test.com",
        industry="technology",
        platforms=["wenxin", "kimi"],
        frequency="weekly",
        status="active",
    )
    async_session.add(tracked_brand)
    await async_session.commit()
    await async_session.refresh(tracked_brand)

    # The competitor is committed together with the query below.
    rival = Competitor(
        id=uuid.uuid4(),
        brand_id=tracked_brand.id,
        name="CompetitorA",
        aliases=["CompA"],
    )
    async_session.add(rival)

    tracked_query = Query(
        id=uuid.uuid4(),
        user_id=_to_uuid(test_user.id),
        keyword="AI assistant",
        target_brand="TestBrand",
        brand_aliases=["TestAlias"],
        platforms=["wenxin", "kimi"],
        frequency="weekly",
        status="active",
        last_queried_at=datetime.now(),
    )
    async_session.add(tracked_query)
    await async_session.commit()
    await async_session.refresh(tracked_query)

    def _citation(**fields):
        """Build a CitationRecord for the tracked query with shared columns."""
        return CitationRecord(
            id=uuid.uuid4(),
            query_id=tracked_query.id,
            raw_response="{}",
            queried_at=datetime.now(),
            **fields,
        )

    records = [
        _citation(
            platform="wenxin",
            cited=True,
            citation_position=1,
            citation_text="TestBrand is a leading AI company...",
            competitor_brands=["CompetitorA"],
            confidence=0.95,
            match_type="exact",
        ),
        _citation(
            platform="kimi",
            cited=True,
            citation_position=2,
            citation_text="TestBrand offers great services...",
            competitor_brands=["CompetitorA"],
            confidence=0.88,
            match_type="alias",
        ),
        # Uncited record: no position, no text, no competitors.
        _citation(
            platform="wenxin",
            cited=False,
            citation_position=None,
            citation_text=None,
            competitor_brands=[],
            confidence=0.3,
            match_type="none",
        ),
    ]
    async_session.add_all(records)

    await async_session.commit()

    return tracked_brand
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an HTTP client bound to the app with DB and auth overridden.

    ``get_db`` is overridden to yield the ``async_session`` fixture and
    ``get_current_user`` to return ``test_user``. All dependency overrides
    are cleared after the client is closed.
    """

    async def _provide_session():
        yield async_session

    async def _provide_user():
        return test_user

    app.dependency_overrides.update(
        {
            get_db: _provide_session,
            get_current_user: _provide_user,
        }
    )

    async with AsyncClient(
        transport=ASGITransport(app=app),
        base_url="http://test",
    ) as http_client:
        yield http_client

    app.dependency_overrides.clear()
|
|
|
|
|
|
@pytest.fixture
def auth_headers(test_user):
    """Return a bearer ``Authorization`` header built from ``test_user``'s id."""
    claims = {"sub": str(test_user.id)}
    token = create_access_token(data=claims)
    headers = {"Authorization": f"Bearer {token}"}
    return headers
|
|
|
|
|
|
class TestGetBrandScore:
    """Test GET /api/v1/brands/{brand_id}/score/ endpoint."""

    @pytest.mark.asyncio
    async def test_get_brand_score_success(
        self, async_client, test_brand
    ):
        """A brand linked to the overridden current user returns every score field."""
        response = await async_client.get(
            f"/api/v1/brands/{test_brand.id}/score/",
        )

        assert response.status_code == 200
        data = response.json()
        assert "mention_rate_score" in data
        assert "sov_score" in data
        assert "quality_score" in data
        assert "overall_score" in data

    @pytest.mark.asyncio
    async def test_get_brand_score_not_found(self, async_client):
        """A random, never-persisted brand id yields 404."""
        fake_id = uuid.uuid4()
        response = await async_client.get(
            f"/api/v1/brands/{fake_id}/score/",
        )

        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_get_brand_score_unauthorized(self, async_session, async_engine):
        """A request without credentials and without an auth override yields 401.

        Only ``get_db`` is overridden, so the application's real
        ``get_current_user`` dependency handles the request.
        """
        session_factory = async_sessionmaker(
            async_engine,
            class_=AsyncSession,
            expire_on_commit=False,
            autoflush=False,
            autocommit=False,
        )
        async with session_factory() as session:

            async def override_get_db():
                yield session

            app.dependency_overrides[get_db] = override_get_db
            # Actually drop any auth override so the real dependency runs;
            # a no-op when none is installed.
            app.dependency_overrides.pop(get_current_user, None)

            # Clean up in ``finally`` so a failing assertion cannot leak the
            # overrides into tests that run afterwards.
            try:
                transport = ASGITransport(app=app)
                async with AsyncClient(transport=transport, base_url="http://test") as client:
                    brand = Brand(
                        id=uuid.uuid4(),
                        user_id=uuid.uuid4(),
                        name="UnauthorizedBrand",
                        platforms=["wenxin"],
                    )
                    session.add(brand)
                    await session.commit()

                    response = await client.get(
                        f"/api/v1/brands/{brand.id}/score/",
                    )

                    assert response.status_code == 401
            finally:
                app.dependency_overrides.clear()
|
|
|
|
|
|
class TestGetBrandScoreHistory:
    """Test GET /api/v1/brands/{brand_id}/score/history/ endpoint."""

    @pytest.mark.asyncio
    async def test_get_brand_score_history_success(self, async_client, test_brand):
        """The history endpoint answers 200 with ``history`` (a list) and ``total``."""
        history_url = f"/api/v1/brands/{test_brand.id}/score/history/"
        response = await async_client.get(history_url)

        assert response.status_code == 200
        payload = response.json()
        for expected_key in ("history", "total"):
            assert expected_key in payload
        assert isinstance(payload["history"], list)

    @pytest.mark.asyncio
    async def test_get_brand_score_history_not_found(self, async_client):
        """A random, never-persisted brand id yields 404."""
        missing_brand_id = uuid.uuid4()
        history_url = f"/api/v1/brands/{missing_brand_id}/score/history/"
        response = await async_client.get(history_url)

        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_get_brand_score_history_with_pagination(self, async_client, test_brand):
        """``skip`` and ``limit`` query parameters are accepted by the endpoint."""
        paged_url = f"/api/v1/brands/{test_brand.id}/score/history/?skip=0&limit=10"
        response = await async_client.get(paged_url)

        assert response.status_code == 200
        payload = response.json()
        assert "history" in payload
|
|
|
|
|
|
class TestGetCitationsStats:
    """Test GET /api/v1/citations/stats endpoint."""

    @pytest.mark.asyncio
    async def test_get_citations_stats_success(
        self, async_client
    ):
        """The stats endpoint answers 200 with the aggregate fields present."""
        response = await async_client.get(
            "/api/v1/citations/stats",
        )

        assert response.status_code == 200
        data = response.json()
        assert "total_queries" in data
        assert "total_citations" in data
        assert "citation_rate" in data
        assert "by_platform" in data

    @pytest.mark.asyncio
    async def test_get_citations_stats_with_brand_filter(
        self, async_client, test_brand
    ):
        """Passing ``brand_id`` as a query parameter still returns the totals."""
        response = await async_client.get(
            f"/api/v1/citations/stats?brand_id={test_brand.id}",
        )

        assert response.status_code == 200
        data = response.json()
        assert "total_queries" in data
        assert "total_citations" in data

    @pytest.mark.asyncio
    async def test_get_citations_stats_unauthorized(self, async_session, async_engine):
        """A request without credentials and without an auth override yields 401.

        Only ``get_db`` is overridden, so the application's real
        ``get_current_user`` dependency handles the request.
        """
        session_factory = async_sessionmaker(
            async_engine,
            class_=AsyncSession,
            expire_on_commit=False,
            autoflush=False,
            autocommit=False,
        )
        async with session_factory() as session:

            async def override_get_db():
                yield session

            app.dependency_overrides[get_db] = override_get_db
            # Actually drop any auth override so the real dependency runs;
            # a no-op when none is installed.
            app.dependency_overrides.pop(get_current_user, None)

            # Clean up in ``finally`` so a failing assertion cannot leak the
            # overrides into tests that run afterwards.
            try:
                transport = ASGITransport(app=app)
                async with AsyncClient(transport=transport, base_url="http://test") as client:
                    response = await client.get("/api/v1/citations/stats")

                    assert response.status_code == 401
            finally:
                app.dependency_overrides.clear()