"""Tests for scoring API endpoints.""" import uuid from datetime import datetime import pytest import pytest_asyncio from httpx import AsyncClient, ASGITransport from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession from app.database import Base from app.main import app from app.models.user import User from app.models.brand import Brand from app.models.competitor import Competitor from app.models.query import Query from app.models.citation_record import CitationRecord from app.api.deps import get_current_user, get_db from app.services.auth import create_access_token from tests.fixtures.auth import _to_uuid @pytest_asyncio.fixture async def async_engine(): """Create async engine for testing with SQLite.""" from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.pool import StaticPool engine = create_async_engine( "sqlite+aiosqlite:///:memory:", connect_args={"check_same_thread": False}, poolclass=StaticPool, ) async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) yield engine await engine.dispose() @pytest_asyncio.fixture async def async_session(async_engine): """Create async session for testing.""" async_session_maker = async_sessionmaker( async_engine, class_=AsyncSession, expire_on_commit=False, autoflush=False, autocommit=False, ) async with async_session_maker() as session: yield session @pytest_asyncio.fixture async def test_user(async_session): """Create a test user.""" user = User( id=str(uuid.uuid4()), email="test@example.com", password="hashed_password", firstName="Test User", plan="free", max_queries=5, isActive=True, emailVerified=True, ) async_session.add(user) await async_session.commit() await async_session.refresh(user) return user @pytest_asyncio.fixture async def test_brand(async_session, test_user): """Create a test brand.""" brand = Brand( id=uuid.uuid4(), user_id=_to_uuid(test_user.id), name="TestBrand", aliases=["TestBrand", "TB"], website="https://testbrand.com", industry="technology", platforms=["wenxin", "kimi"], frequency="weekly", status="active", ) async_session.add(brand) await async_session.commit() await async_session.refresh(brand) return brand @pytest_asyncio.fixture async def test_brand_with_data(async_session: AsyncSession, test_user): """Create a test brand with query and citation data.""" brand = Brand( id=uuid.uuid4(), user_id=_to_uuid(test_user.id), name="TestBrand", aliases=["TestAlias"], website="https://test.com", industry="technology", platforms=["wenxin", "kimi"], frequency="weekly", status="active", ) async_session.add(brand) await async_session.commit() await async_session.refresh(brand) # Create a competitor competitor = Competitor( id=uuid.uuid4(), brand_id=brand.id, name="CompetitorA", aliases=["CompA"], ) async_session.add(competitor) # Create a query query = Query( id=uuid.uuid4(), user_id=_to_uuid(test_user.id), keyword="AI assistant", target_brand="TestBrand", brand_aliases=["TestAlias"], platforms=["wenxin", "kimi"], frequency="weekly", status="active", last_queried_at=datetime.now(), ) async_session.add(query) await async_session.commit() await async_session.refresh(query) # Create citation records citations = [ CitationRecord( id=uuid.uuid4(), query_id=query.id, platform="wenxin", cited=True, citation_position=1, citation_text="TestBrand is a leading AI company...", competitor_brands=["CompetitorA"], raw_response="{}", confidence=0.95, match_type="exact", queried_at=datetime.now(), ), CitationRecord( id=uuid.uuid4(), query_id=query.id, platform="kimi", cited=True, citation_position=2, citation_text="TestBrand offers great services...", competitor_brands=["CompetitorA"], raw_response="{}", confidence=0.88, match_type="alias", queried_at=datetime.now(), ), CitationRecord( id=uuid.uuid4(), query_id=query.id, platform="wenxin", cited=False, citation_position=None, citation_text=None, competitor_brands=[], raw_response="{}", confidence=0.3, match_type="none", queried_at=datetime.now(), ), ] for citation in citations: async_session.add(citation) await async_session.commit() return brand @pytest_asyncio.fixture async def async_client(async_session, test_user): """Create async client for API testing.""" async def override_get_db(): yield async_session async def override_get_current_user(): return test_user app.dependency_overrides[get_db] = override_get_db app.dependency_overrides[get_current_user] = override_get_current_user transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: yield client app.dependency_overrides.clear() @pytest.fixture def auth_headers(test_user): """Create authentication headers.""" token = create_access_token(data={"sub": str(test_user.id)}) return {"Authorization": f"Bearer {token}"} class TestGetBrandScore: """Test GET /api/v1/brands/{brand_id}/score/ endpoint.""" @pytest.mark.asyncio async def test_get_brand_score_success( self, async_client, test_brand ): """Test getting brand score successfully.""" response = await async_client.get( f"/api/v1/brands/{test_brand.id}/score/", ) assert response.status_code == 200 data = response.json() assert "mention_rate_score" in data assert "sov_score" in data assert "quality_score" in data assert "overall_score" in data @pytest.mark.asyncio async def test_get_brand_score_not_found(self, async_client): """Test getting score for non-existent brand.""" fake_id = uuid.uuid4() response = await async_client.get( f"/api/v1/brands/{fake_id}/score/", ) assert response.status_code == 404 @pytest.mark.asyncio async def test_get_brand_score_unauthorized(self, async_session, async_engine): """Test getting score without authorization.""" from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession async_session_maker = async_sessionmaker( async_engine, class_=AsyncSession, expire_on_commit=False, autoflush=False, autocommit=False, ) async with async_session_maker() as session: async def override_get_db(): yield session app.dependency_overrides[get_db] = override_get_db # Remove auth override to test unauthorized transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: brand = Brand( id=uuid.uuid4(), user_id=uuid.uuid4(), name="UnauthorizedBrand", platforms=["wenxin"], ) session.add(brand) await session.commit() response = await client.get( f"/api/v1/brands/{brand.id}/score/", ) assert response.status_code == 401 app.dependency_overrides.clear() class TestGetBrandScoreHistory: """Test GET /api/v1/brands/{brand_id}/score/history/ endpoint.""" @pytest.mark.asyncio async def test_get_brand_score_history_success( self, async_client, test_brand ): """Test getting brand score history successfully.""" response = await async_client.get( f"/api/v1/brands/{test_brand.id}/score/history/", ) assert response.status_code == 200 data = response.json() assert "history" in data assert "total" in data assert isinstance(data["history"], list) @pytest.mark.asyncio async def test_get_brand_score_history_not_found( self, async_client ): """Test getting history for non-existent brand.""" fake_id = uuid.uuid4() response = await async_client.get( f"/api/v1/brands/{fake_id}/score/history/", ) assert response.status_code == 404 @pytest.mark.asyncio async def test_get_brand_score_history_with_pagination( self, async_client, test_brand ): """Test score history pagination.""" response = await async_client.get( f"/api/v1/brands/{test_brand.id}/score/history/?skip=0&limit=10", ) assert response.status_code == 200 data = response.json() assert "history" in data class TestGetCitationsStats: """Test GET /api/v1/citations/stats/ endpoint.""" @pytest.mark.asyncio async def test_get_citations_stats_success( self, async_client ): """Test getting citations stats successfully.""" response = await async_client.get( "/api/v1/citations/stats", ) assert response.status_code == 200 data = response.json() assert "total_queries" in data assert "total_citations" in data assert "citation_rate" in data assert "by_platform" in data @pytest.mark.asyncio async def test_get_citations_stats_with_brand_filter( self, async_client, test_brand ): """Test getting citations stats filtered by brand.""" response = await async_client.get( f"/api/v1/citations/stats?brand_id={test_brand.id}", ) assert response.status_code == 200 data = response.json() assert "total_queries" in data assert "total_citations" in data @pytest.mark.asyncio async def test_get_citations_stats_unauthorized(self, async_session, async_engine): """Test getting citations stats without authorization.""" from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession async_session_maker = async_sessionmaker( async_engine, class_=AsyncSession, expire_on_commit=False, autoflush=False, autocommit=False, ) async with async_session_maker() as session: async def override_get_db(): yield session app.dependency_overrides[get_db] = override_get_db # Remove auth override to test unauthorized transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: response = await client.get("/api/v1/citations/stats") assert response.status_code == 401 app.dependency_overrides.clear()