217 lines
7.1 KiB
Python
217 lines
7.1 KiB
Python
import uuid
|
|
from datetime import UTC, datetime
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from httpx import ASGITransport, AsyncClient
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
from sqlalchemy.pool import StaticPool
|
|
|
|
from app.database import Base
|
|
from app.main import app
|
|
from app.models.user import User
|
|
from app.api.deps import get_current_user, get_db
|
|
from app.services.auth import hash_password
|
|
from app.services.ai_engine.base import AIQueryResult, EngineType
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_engine():
    """Provide a throwaway in-memory SQLite engine with the full schema created.

    ``StaticPool`` together with ``check_same_thread=False`` lets every
    connection reuse the same in-memory database, so the tables created here
    are visible to the sessions opened by the other fixtures. The engine is
    disposed once the test finishes.
    """
    test_engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    async with test_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield test_engine

    await test_engine.dispose()


@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield a single ``AsyncSession`` bound to the test engine.

    ``expire_on_commit=False`` keeps committed ORM instances (such as the one
    returned by ``test_user``) readable without another round-trip, and
    ``autoflush=False`` means pending changes are only flushed explicitly or
    on commit.
    """
    make_session = async_sessionmaker(
        bind=async_engine,
        class_=AsyncSession,
        autocommit=False,
        autoflush=False,
        expire_on_commit=False,
    )
    async with make_session() as session:
        yield session


@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist and return an active, email-verified ``User`` on the free plan.

    The account logs in as ``test@example.com`` / ``Test@123456`` and has
    ``max_queries=5``.
    """
    account = User(
        id=uuid.uuid4(),
        name="Test User",
        email="test@example.com",
        password_hash=hash_password("Test@123456"),
        plan="free",
        max_queries=5,
        is_active=True,
        email_verified=True,
    )
    async_session.add(account)
    await async_session.commit()
    # Reload the row so any database-populated attributes are present.
    await async_session.refresh(account)
    return account


@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an ``AsyncClient`` wired to the app with DB and auth overridden.

    ``get_db`` is overridden to hand out this test's ``async_session`` and
    ``get_current_user`` to return ``test_user``, so requests bypass real
    authentication. Both overrides live on the global ``app`` object, so they
    are cleared in a ``finally`` block: they must not leak into later tests
    even if building or closing the client raises.
    """

    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        # Always reset global app state; a leaked auth override would make
        # unrelated tests silently skip authentication.
        app.dependency_overrides.clear()


def _make_result(
    engine_type: EngineType,
    query: str = "best insurance",
    has_brand: bool = False,
    has_competitor: bool = False,
) -> AIQueryResult:
    """Build a canned ``AIQueryResult`` for mocking the batch service.

    Args:
        engine_type: Engine the result is attributed to.
        query: Query text echoed back in the result.
        has_brand: When truthy, flag a brand citation and attach a brand
            context snippet; otherwise ``brand_context`` is ``None``.
        has_competitor: When truthy, flag a competitor citation and attach a
            single competitor context; otherwise the context list is empty.

    Returns:
        A result with a fixed response body, no citations, a 150 ms response
        time and a timezone-aware UTC timestamp taken at call time.
    """
    brand_context = None
    if has_brand:
        brand_context = "BrandX is great"

    competitor_contexts = []
    if has_competitor:
        competitor_contexts = ["CompY is ok"]

    return AIQueryResult(
        engine_type=engine_type,
        query=query,
        raw_response="BrandX is a great insurance company",
        citations=[],
        has_brand_citation=has_brand,
        has_competitor_citation=has_competitor,
        brand_context=brand_context,
        competitor_contexts=competitor_contexts,
        response_time_ms=150,
        timestamp=datetime.now(UTC),
    )


class TestSingleQueryEndpoint:
    """Tests for ``POST /api/v1/ai-engines/query``."""

    @pytest.mark.asyncio
    async def test_query_single_engine(self, async_client):
        """One-engine query responds 200 with the mocked result's key fields."""
        canned = _make_result(EngineType.CHATGPT, has_brand=True)
        fake_service = AsyncMock()
        fake_service.query_single.return_value = canned
        payload = {
            "engine": "chatgpt",
            "query": "best insurance",
            "brand_name": "BrandX",
            "competitor_names": ["CompY"],
        }

        # Extra keyword arguments to ``patch`` configure the replacement mock,
        # so the patched factory returns ``fake_service``.
        with patch(
            "app.api.ai_engines.get_batch_service", return_value=fake_service
        ):
            response = await async_client.post(
                "/api/v1/ai-engines/query", json=payload
            )

        assert response.status_code == 200
        body = response.json()
        assert body["engine_type"] == "chatgpt"
        assert body["has_brand_citation"] is True
        assert body["query"] == "best insurance"


class TestBatchQueryEndpoint:
    """Tests for ``POST /api/v1/ai-engines/query-batch``."""

    @pytest.mark.asyncio
    async def test_query_batch_parallel(self, async_client):
        """A two-engine batch returns both results and the citation summary."""
        chatgpt_result = _make_result(EngineType.CHATGPT, has_brand=True)
        perplexity_result = _make_result(
            EngineType.PERPLEXITY, has_brand=False, has_competitor=True
        )
        citation_summary = {
            "total_engines": 2,
            "brand_citation_count": 1,
            "brand_citation_rate": 0.5,
            "competitor_citation_count": 1,
            "competitor_citation_rate": 0.5,
        }

        fake_service = AsyncMock()
        fake_service.query_batch.return_value = [chatgpt_result, perplexity_result]
        # AsyncMock children return coroutines when called; a plain MagicMock
        # is used here because the endpoint presumably calls
        # calculate_citation_rate synchronously.
        fake_service.calculate_citation_rate = MagicMock(return_value=citation_summary)
        payload = {
            "engines": ["chatgpt", "perplexity"],
            "query": "best insurance",
            "brand_name": "BrandX",
            "competitor_names": ["CompY"],
        }

        with patch(
            "app.api.ai_engines.get_batch_service", return_value=fake_service
        ):
            response = await async_client.post(
                "/api/v1/ai-engines/query-batch", json=payload
            )

        assert response.status_code == 200
        body = response.json()
        assert "results" in body
        assert "citation_rate" in body
        assert len(body["results"]) == 2
        assert body["citation_rate"]["brand_citation_rate"] == 0.5


class TestGetResultsEndpoint:
    """Tests for ``GET /api/v1/ai-engines/results``."""

    @pytest.mark.asyncio
    async def test_get_results(self, async_client):
        """Fetching results for two engines responds 200 with both top-level keys."""
        chatgpt_result = _make_result(EngineType.CHATGPT, has_brand=True)
        kimi_result = _make_result(EngineType.KIMI, has_brand=False)
        citation_summary = {
            "total_engines": 2,
            "brand_citation_count": 1,
            "brand_citation_rate": 0.5,
            "competitor_citation_count": 0,
            "competitor_citation_rate": 0.0,
        }

        fake_service = AsyncMock()
        fake_service.query_batch.return_value = [chatgpt_result, kimi_result]
        fake_service.calculate_citation_rate = MagicMock(return_value=citation_summary)
        # The engines are sent as a single comma-separated query-string value.
        query_params = {
            "engines": "chatgpt,kimi",
            "query": "best insurance",
            "brand_name": "BrandX",
        }

        with patch(
            "app.api.ai_engines.get_batch_service", return_value=fake_service
        ):
            response = await async_client.get(
                "/api/v1/ai-engines/results", params=query_params
            )

        assert response.status_code == 200
        body = response.json()
        for expected_key in ("results", "citation_rate"):
            assert expected_key in body


class TestUnauthorizedAccess:
    """Authentication is enforced when ``get_current_user`` is not overridden."""

    @pytest.mark.asyncio
    async def test_unauthorized_returns_401(self, async_session):
        """A request with an invalid bearer token is rejected with HTTP 401.

        Only ``get_db`` is overridden, so the app's own ``get_current_user``
        dependency is exercised. Overrides are reset in ``finally`` so a failing
        assertion cannot leak a ``get_db`` override (bound to a session whose
        engine is about to be disposed) into later tests.
        """

        async def override_get_db():
            yield async_session

        # Drop any stale auth override so it cannot mask the 401 we expect.
        app.dependency_overrides.pop(get_current_user, None)
        app.dependency_overrides[get_db] = override_get_db
        try:
            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                response = await client.post(
                    "/api/v1/ai-engines/query",
                    json={
                        "engine": "chatgpt",
                        "query": "test",
                        "brand_name": "BrandX",
                    },
                    headers={"Authorization": "Bearer invalid_token"},
                )
            assert response.status_code == 401
        finally:
            app.dependency_overrides.clear()