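# geo/backend/tests/test_services/test_ai_engine_query.py
#
# Tests for the AI engine layer: EngineType, CitationInfo, AIQueryResult,
# AIEngineAdapter citation detection, and the ChatGPT / Perplexity adapters.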

import pytest
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock, patch
from app.services.ai_engine.base import (
AIEngineAdapter,
AIQueryResult,
CitationInfo,
EngineType,
)
from app.services.ai_engine.chatgpt import ChatGPTAdapter
from app.services.ai_engine.perplexity import PerplexityAdapter
class TestEngineType:
    """EngineType members compare equal to their string values."""

    def test_engine_type_values(self):
        # Each entry pairs an enum member with the string it must equal.
        expected_pairs = (
            (EngineType.CHATGPT, "chatgpt"),
            (EngineType.PERPLEXITY, "perplexity"),
            (EngineType.KIMI, "kimi"),
            (EngineType.WENXIN, "wenxin"),
            (EngineType.DOUBAO, "doubao"),
            (EngineType.DEEPSEEK, "deepseek"),
            (EngineType.QWEN, "qwen"),
            (EngineType.YUANBAO, "yuanbao"),
        )
        for member, value in expected_pairs:
            assert member == value


class TestCitationInfo:
    """Construction and attribute access for CitationInfo."""

    def test_citation_info_creation(self):
        fields = {
            "source_url": "https://example.com",
            "source_title": "Example Title",
            "citation_context": "brand was mentioned here",
            "confidence": 0.95,
            "position": 1,
        }
        citation = CitationInfo(**fields)
        # Every constructor argument must be readable back unchanged.
        for attr_name, expected in fields.items():
            assert getattr(citation, attr_name) == expected

    def test_citation_info_optional_fields(self):
        citation = CitationInfo(
            source_url=None,
            source_title=None,
            citation_context="some context",
            confidence=0.5,
            position=3,
        )
        # URL and title may be passed as None and are stored as-is.
        for attr_name in ("source_url", "source_title"):
            assert getattr(citation, attr_name) is None


class TestAIQueryResult:
    """Construction, nested citations and defaults for AIQueryResult."""

    def test_ai_query_result_creation(self):
        created_at = datetime.now(UTC)
        answer = "I recommend BrandX for insurance."
        question = "best insurance companies"

        result = AIQueryResult(
            engine_type=EngineType.CHATGPT,
            query=question,
            raw_response=answer,
            citations=[],
            has_brand_citation=True,
            has_competitor_citation=False,
            brand_context=answer,
            competitor_contexts=[],
            response_time_ms=1500,
            timestamp=created_at,
        )

        assert result.engine_type == EngineType.CHATGPT
        assert result.query == question
        assert result.raw_response == answer
        assert result.has_brand_citation is True
        assert result.has_competitor_citation is False
        assert result.brand_context == answer
        assert result.competitor_contexts == []
        assert result.response_time_ms == 1500
        assert result.timestamp == created_at

    def test_ai_query_result_with_citations(self):
        brand_url = "https://brandx.com"
        source = CitationInfo(
            source_url=brand_url,
            source_title="BrandX Official",
            citation_context="BrandX is a leading provider",
            confidence=0.9,
            position=1,
        )
        reply = "BrandX is great"

        result = AIQueryResult(
            engine_type=EngineType.PERPLEXITY,
            query="insurance comparison",
            raw_response=reply,
            citations=[source],
            has_brand_citation=True,
            has_competitor_citation=False,
            brand_context=reply,
            competitor_contexts=[],
            response_time_ms=2000,
            timestamp=datetime.now(UTC),
        )

        assert len(result.citations) == 1
        first_citation = result.citations[0]
        assert first_citation.source_url == brand_url

    def test_ai_query_result_default_metadata(self):
        # metadata is not passed here, so this checks its default value.
        result = AIQueryResult(
            engine_type=EngineType.CHATGPT,
            query="test",
            raw_response="test",
            citations=[],
            has_brand_citation=False,
            has_competitor_citation=False,
            brand_context=None,
            competitor_contexts=[],
            response_time_ms=100,
            timestamp=datetime.now(UTC),
        )
        assert result.metadata == {}


class TestAIEngineAdapterBase:
    """Tests for behaviour shared by every AIEngineAdapter subclass."""

    class _StubAdapter(AIEngineAdapter):
        """Minimal concrete subclass used to exercise base-class helpers.

        Defined once here instead of being copy-pasted into every test, so a
        change to the abstract interface only needs to be mirrored in one
        place. The leading underscore keeps pytest from collecting it.
        """

        async def query(self, query, brand_name, competitor_names=None):
            pass

        def get_engine_type(self):
            return EngineType.CHATGPT

        def _get_env_key(self) -> str | None:
            return None

    def _make_adapter(self):
        """Return a stub adapter constructed with a dummy API key."""
        return self._StubAdapter(api_key="test-key")

    def test_cannot_instantiate_abstract_class(self):
        with pytest.raises(TypeError):
            AIEngineAdapter(api_key="test-key")

    def test_detect_citations_brand_found(self):
        adapter = self._make_adapter()
        has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
            "BrandX is the best insurance company",
            "BrandX",
            None,
        )
        assert has_brand is True
        assert has_comp is False
        assert brand_ctx is not None
        assert "BrandX" in brand_ctx

    def test_detect_citations_competitor_found(self):
        adapter = self._make_adapter()
        has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
            "CompetitorY is also a good choice for insurance",
            "BrandX",
            ["CompetitorY", "CompetitorZ"],
        )
        assert has_brand is False
        assert has_comp is True
        # Only CompetitorY appears in the text, so exactly one context is expected.
        assert len(comp_ctx) == 1
        assert "CompetitorY" in comp_ctx[0]

    def test_detect_citations_both_found(self):
        adapter = self._make_adapter()
        has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
            "BrandX and CompetitorY are both good insurance options",
            "BrandX",
            ["CompetitorY"],
        )
        assert has_brand is True
        assert has_comp is True
        assert brand_ctx is not None
        assert len(comp_ctx) == 1

    def test_detect_citations_none_found(self):
        adapter = self._make_adapter()
        has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
            "Some random text without brand names",
            "BrandX",
            ["CompetitorY"],
        )
        assert has_brand is False
        assert has_comp is False
        assert brand_ctx is None
        assert comp_ctx == []

    def test_detect_citations_case_insensitive(self):
        adapter = self._make_adapter()
        # Lower-case mention in the text must still match the mixed-case brand name.
        has_brand, _, _, _ = adapter._detect_citations(
            "brandx is great",
            "BrandX",
            None,
        )
        assert has_brand is True


class TestChatGPTAdapter:
    """Tests for ChatGPTAdapter with the HTTP client's ``post`` method mocked out."""

    @staticmethod
    def _ok_response(payload):
        """Return a mock HTTP 200 response whose ``json()`` yields ``payload``."""
        response = MagicMock()
        response.status_code = 200
        response.json.return_value = payload
        return response

    @staticmethod
    def _error_response(status_code, text):
        """Return a mock HTTP error response that signals failure on every check.

        A bare ``MagicMock`` makes ``raise_for_status()`` a silent no-op and
        ``is_success`` truthy. An adapter relying on either would then go on
        to parse a ``MagicMock`` payload, and the error test would pass (or
        fail) only by accident. Configuring the status code, the success
        flags and ``raise_for_status`` together makes the error path
        deterministic whichever check the adapter performs.
        """
        import httpx

        response = MagicMock()
        response.status_code = status_code
        response.text = text
        response.is_success = False
        response.is_error = True
        response.raise_for_status.side_effect = httpx.HTTPStatusError(
            f"HTTP {status_code}",
            request=httpx.Request("POST", "https://example.invalid"),
            response=response,
        )
        return response

    @pytest.fixture
    def chatgpt_adapter(self):
        """Adapter constructed with a dummy API key."""
        return ChatGPTAdapter(api_key="test-api-key")

    def test_chatgpt_init(self, chatgpt_adapter):
        assert chatgpt_adapter.api_key == "test-api-key"
        assert chatgpt_adapter.get_engine_type() == EngineType.CHATGPT

    @pytest.mark.asyncio
    async def test_chatgpt_query_success(self, chatgpt_adapter):
        mock_response = self._ok_response(
            {
                "choices": [
                    {
                        "message": {
                            "content": "BrandX is a leading insurance company with great service."
                        }
                    }
                ],
                "model": "gpt-4o",
                "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
            }
        )
        with patch.object(chatgpt_adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response
            result = await chatgpt_adapter.query(
                query="best insurance companies",
                brand_name="BrandX",
                competitor_names=["CompetitorY"],
            )
        assert isinstance(result, AIQueryResult)
        assert result.engine_type == EngineType.CHATGPT
        assert result.query == "best insurance companies"
        assert "BrandX" in result.raw_response
        assert result.has_brand_citation is True
        assert result.response_time_ms >= 0

    @pytest.mark.asyncio
    async def test_chatgpt_query_with_rate_limiter(self):
        mock_limiter = AsyncMock()
        adapter = ChatGPTAdapter(api_key="test-key", rate_limiter=mock_limiter)
        mock_response = self._ok_response(
            {
                "choices": [{"message": {"content": "Some response"}}],
                "model": "gpt-4o",
            }
        )
        with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response
            await adapter.query(query="test", brand_name="BrandX")
        mock_limiter.acquire.assert_awaited()

    @pytest.mark.asyncio
    async def test_chatgpt_query_api_timeout(self, chatgpt_adapter):
        import httpx

        with patch.object(chatgpt_adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.side_effect = httpx.TimeoutException("Request timed out")
            with pytest.raises(Exception):
                await chatgpt_adapter.query(
                    query="test query",
                    brand_name="BrandX",
                )

    @pytest.mark.asyncio
    async def test_chatgpt_query_invalid_response(self, chatgpt_adapter):
        mock_response = self._error_response(401, "Unauthorized")
        with patch.object(chatgpt_adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response
            with pytest.raises(Exception):
                await chatgpt_adapter.query(
                    query="test query",
                    brand_name="BrandX",
                )

    @pytest.mark.asyncio
    async def test_chatgpt_brand_citation_detection(self, chatgpt_adapter):
        mock_response = self._ok_response(
            {
                "choices": [
                    {"message": {"content": "BrandX offers excellent insurance coverage."}}
                ],
                "model": "gpt-4o",
            }
        )
        with patch.object(chatgpt_adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response
            result = await chatgpt_adapter.query(
                query="insurance",
                brand_name="BrandX",
            )
        assert result.has_brand_citation is True
        assert result.brand_context is not None
        assert "BrandX" in result.brand_context

    @pytest.mark.asyncio
    async def test_chatgpt_competitor_citation_detection(self, chatgpt_adapter):
        mock_response = self._ok_response(
            {
                "choices": [
                    {
                        "message": {
                            "content": "CompetitorY and CompetitorZ are popular insurance providers."
                        }
                    }
                ],
                "model": "gpt-4o",
            }
        )
        with patch.object(chatgpt_adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response
            result = await chatgpt_adapter.query(
                query="insurance",
                brand_name="BrandX",
                competitor_names=["CompetitorY", "CompetitorZ"],
            )
        assert result.has_brand_citation is False
        assert result.has_competitor_citation is True
        # Both competitors are named in the response, one context each.
        assert len(result.competitor_contexts) == 2


class TestPerplexityAdapter:
    """Tests for PerplexityAdapter with the HTTP client's ``post`` method mocked out."""

    @staticmethod
    def _ok_response(payload):
        """Return a mock HTTP 200 response whose ``json()`` yields ``payload``."""
        response = MagicMock()
        response.status_code = 200
        response.json.return_value = payload
        return response

    @staticmethod
    def _error_response(status_code, text):
        """Return a mock HTTP error response that signals failure on every check.

        A bare ``MagicMock`` makes ``raise_for_status()`` a silent no-op and
        ``is_success`` truthy. An adapter relying on either would then go on
        to parse a ``MagicMock`` payload, and the error test would pass (or
        fail) only by accident. Configuring the status code, the success
        flags and ``raise_for_status`` together makes the error path
        deterministic whichever check the adapter performs.
        """
        import httpx

        response = MagicMock()
        response.status_code = status_code
        response.text = text
        response.is_success = False
        response.is_error = True
        response.raise_for_status.side_effect = httpx.HTTPStatusError(
            f"HTTP {status_code}",
            request=httpx.Request("POST", "https://example.invalid"),
            response=response,
        )
        return response

    @pytest.fixture
    def perplexity_adapter(self):
        """Adapter constructed with a dummy API key."""
        return PerplexityAdapter(api_key="test-api-key")

    def test_perplexity_init(self, perplexity_adapter):
        assert perplexity_adapter.api_key == "test-api-key"
        assert perplexity_adapter.get_engine_type() == EngineType.PERPLEXITY

    @pytest.mark.asyncio
    async def test_perplexity_query_success(self, perplexity_adapter):
        mock_response = self._ok_response(
            {
                "choices": [
                    {
                        "message": {
                            "content": "BrandX is a well-known insurance brand [1]."
                        }
                    }
                ],
                "citations": [
                    {"url": "https://brandx.com", "title": "BrandX Official Site"}
                ],
                "model": "pplx-70b-online",
            }
        )
        with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response
            result = await perplexity_adapter.query(
                query="best insurance companies",
                brand_name="BrandX",
                competitor_names=["CompetitorY"],
            )
        assert isinstance(result, AIQueryResult)
        assert result.engine_type == EngineType.PERPLEXITY
        assert result.query == "best insurance companies"
        assert "BrandX" in result.raw_response
        assert result.has_brand_citation is True
        assert len(result.citations) >= 1

    @pytest.mark.asyncio
    async def test_perplexity_query_with_citations(self, perplexity_adapter):
        mock_response = self._ok_response(
            {
                "choices": [
                    {"message": {"content": "BrandX is recommended [1]. CompetitorY is also good [2]."}}
                ],
                "citations": [
                    {"url": "https://brandx.com", "title": "BrandX"},
                    {"url": "https://competitory.com", "title": "CompetitorY"},
                ],
                "model": "pplx-70b-online",
            }
        )
        with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response
            result = await perplexity_adapter.query(
                query="insurance",
                brand_name="BrandX",
                competitor_names=["CompetitorY"],
            )
        # Citations are expected in the same order as in the API payload.
        assert len(result.citations) == 2
        assert result.citations[0].source_url == "https://brandx.com"
        assert result.citations[1].source_url == "https://competitory.com"

    @pytest.mark.asyncio
    async def test_perplexity_query_with_rate_limiter(self):
        mock_limiter = AsyncMock()
        adapter = PerplexityAdapter(api_key="test-key", rate_limiter=mock_limiter)
        mock_response = self._ok_response(
            {
                "choices": [{"message": {"content": "Some response"}}],
                "model": "pplx-70b-online",
                "citations": [],
            }
        )
        with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response
            await adapter.query(query="test", brand_name="BrandX")
        mock_limiter.acquire.assert_awaited()

    @pytest.mark.asyncio
    async def test_perplexity_query_api_timeout(self, perplexity_adapter):
        import httpx

        with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.side_effect = httpx.TimeoutException("Request timed out")
            with pytest.raises(Exception):
                await perplexity_adapter.query(
                    query="test query",
                    brand_name="BrandX",
                )

    @pytest.mark.asyncio
    async def test_perplexity_query_invalid_response(self, perplexity_adapter):
        mock_response = self._error_response(401, "Unauthorized")
        with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response
            with pytest.raises(Exception):
                await perplexity_adapter.query(
                    query="test query",
                    brand_name="BrandX",
                )

    @pytest.mark.asyncio
    async def test_perplexity_brand_citation_detection(self, perplexity_adapter):
        mock_response = self._ok_response(
            {
                "choices": [
                    {"message": {"content": "BrandX offers excellent insurance coverage."}}
                ],
                "citations": [],
                "model": "pplx-70b-online",
            }
        )
        with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response
            result = await perplexity_adapter.query(
                query="insurance",
                brand_name="BrandX",
            )
        assert result.has_brand_citation is True
        assert result.brand_context is not None
        # Same check as the ChatGPT and base-class detection tests.
        assert "BrandX" in result.brand_context

    @pytest.mark.asyncio
    async def test_perplexity_competitor_citation_detection(self, perplexity_adapter):
        mock_response = self._ok_response(
            {
                "choices": [
                    {"message": {"content": "CompetitorY is a popular insurance provider."}}
                ],
                "citations": [],
                "model": "pplx-70b-online",
            }
        )
        with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response
            result = await perplexity_adapter.query(
                query="insurance",
                brand_name="BrandX",
                competitor_names=["CompetitorY"],
            )
        assert result.has_brand_citation is False
        assert result.has_competitor_citation is True
        assert len(result.competitor_contexts) == 1

    @pytest.mark.asyncio
    async def test_perplexity_no_citations_field(self, perplexity_adapter):
        # The payload deliberately omits the "citations" key entirely.
        mock_response = self._ok_response(
            {
                "choices": [{"message": {"content": "Some response without citations field"}}],
                "model": "pplx-70b-online",
            }
        )
        with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response
            result = await perplexity_adapter.query(
                query="test",
                brand_name="BrandX",
            )
        assert result.citations == []