530 lines
19 KiB
Python
530 lines
19 KiB
Python
import pytest
|
|
from datetime import UTC, datetime
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
from app.services.ai_engine.base import (
|
|
AIEngineAdapter,
|
|
AIQueryResult,
|
|
CitationInfo,
|
|
EngineType,
|
|
)
|
|
from app.services.ai_engine.chatgpt import ChatGPTAdapter
|
|
from app.services.ai_engine.perplexity import PerplexityAdapter
|
|
|
|
|
|
class TestEngineType:
    """Checks the string value carried by each ``EngineType`` member."""

    def test_engine_type_values(self):
        # (member, expected string) pairs, checked in the listed order so a
        # failure points at the first mismatching member.
        expectations = [
            (EngineType.CHATGPT, "chatgpt"),
            (EngineType.PERPLEXITY, "perplexity"),
            (EngineType.KIMI, "kimi"),
            (EngineType.WENXIN, "wenxin"),
            (EngineType.DOUBAO, "doubao"),
            (EngineType.DEEPSEEK, "deepseek"),
            (EngineType.QWEN, "qwen"),
        ]
        for member, expected in expectations:
            assert member == expected
|
|
|
|
|
|
class TestCitationInfo:
    """Construction tests for the ``CitationInfo`` record."""

    def test_citation_info_creation(self):
        # Every keyword passed to the constructor must be readable back as an
        # attribute of the same name with the same value.
        fields = {
            "source_url": "https://example.com",
            "source_title": "Example Title",
            "citation_context": "brand was mentioned here",
            "confidence": 0.95,
            "position": 1,
        }
        info = CitationInfo(**fields)
        for name, value in fields.items():
            assert getattr(info, name) == value

    def test_citation_info_optional_fields(self):
        # URL and title may be absent (None) while the remaining fields are set.
        info = CitationInfo(
            source_url=None,
            source_title=None,
            citation_context="some context",
            confidence=0.5,
            position=3,
        )
        assert info.source_url is None
        assert info.source_title is None
|
|
|
|
|
|
class TestAIQueryResult:
    """Construction tests for the ``AIQueryResult`` record."""

    def test_ai_query_result_creation(self):
        now = datetime.now(UTC)
        question = "best insurance companies"
        answer = "I recommend BrandX for insurance."

        result = AIQueryResult(
            engine_type=EngineType.CHATGPT,
            query=question,
            raw_response=answer,
            citations=[],
            has_brand_citation=True,
            has_competitor_citation=False,
            brand_context=answer,
            competitor_contexts=[],
            response_time_ms=1500,
            timestamp=now,
        )

        # Each constructor argument is stored unchanged on the instance.
        assert result.engine_type == EngineType.CHATGPT
        assert result.query == question
        assert result.raw_response == answer
        assert result.has_brand_citation is True
        assert result.has_competitor_citation is False
        assert result.brand_context == answer
        assert result.competitor_contexts == []
        assert result.response_time_ms == 1500
        assert result.timestamp == now

    def test_ai_query_result_with_citations(self):
        brand_citation = CitationInfo(
            source_url="https://brandx.com",
            source_title="BrandX Official",
            citation_context="BrandX is a leading provider",
            confidence=0.9,
            position=1,
        )
        stamp = datetime.now(UTC)

        result = AIQueryResult(
            engine_type=EngineType.PERPLEXITY,
            query="insurance comparison",
            raw_response="BrandX is great",
            citations=[brand_citation],
            has_brand_citation=True,
            has_competitor_citation=False,
            brand_context="BrandX is great",
            competitor_contexts=[],
            response_time_ms=2000,
            timestamp=stamp,
        )

        # The citation list is kept as given.
        assert len(result.citations) == 1
        assert result.citations[0].source_url == "https://brandx.com"

    def test_ai_query_result_default_metadata(self):
        # ``metadata`` is deliberately not passed: it must default to an empty dict.
        result = AIQueryResult(
            engine_type=EngineType.CHATGPT,
            query="test",
            raw_response="test",
            citations=[],
            has_brand_citation=False,
            has_competitor_citation=False,
            brand_context=None,
            competitor_contexts=[],
            response_time_ms=100,
            timestamp=datetime.now(UTC),
        )
        assert result.metadata == {}
|
|
|
|
|
|
class TestAIEngineAdapterBase:
    """Tests for the shared behaviour implemented on ``AIEngineAdapter``."""

    @pytest.fixture
    def adapter(self):
        """Return a minimal concrete adapter for exercising base-class helpers.

        The subclass implements ``query`` and ``get_engine_type`` so it can be
        instantiated. ``query`` is a no-op because these tests only call
        ``_detect_citations`` directly.
        """

        class ConcreteAdapter(AIEngineAdapter):
            async def query(self, query, brand_name, competitor_names=None):
                pass

            def get_engine_type(self):
                return EngineType.CHATGPT

        return ConcreteAdapter(api_key="test-key")

    def test_cannot_instantiate_abstract_class(self):
        # The base class is abstract; instantiating it directly must raise TypeError.
        with pytest.raises(TypeError):
            AIEngineAdapter(api_key="test-key")

    def test_detect_citations_brand_found(self, adapter):
        has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
            "BrandX is the best insurance company",
            "BrandX",
            None,
        )

        assert has_brand is True
        assert has_comp is False
        assert brand_ctx is not None
        assert "BrandX" in brand_ctx

    def test_detect_citations_competitor_found(self, adapter):
        has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
            "CompetitorY is also a good choice for insurance",
            "BrandX",
            ["CompetitorY", "CompetitorZ"],
        )

        assert has_brand is False
        assert has_comp is True
        # Only the competitor actually present in the text yields a context.
        assert len(comp_ctx) == 1
        assert "CompetitorY" in comp_ctx[0]

    def test_detect_citations_both_found(self, adapter):
        has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
            "BrandX and CompetitorY are both good insurance options",
            "BrandX",
            ["CompetitorY"],
        )

        assert has_brand is True
        assert has_comp is True
        assert brand_ctx is not None
        assert len(comp_ctx) == 1

    def test_detect_citations_none_found(self, adapter):
        has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
            "Some random text without brand names",
            "BrandX",
            ["CompetitorY"],
        )

        assert has_brand is False
        assert has_comp is False
        assert brand_ctx is None
        assert comp_ctx == []

    def test_detect_citations_case_insensitive(self, adapter):
        # A lowercase mention in the text must still match the brand name.
        has_brand, _, _, _ = adapter._detect_citations(
            "brandx is great",
            "BrandX",
            None,
        )

        assert has_brand is True
|
|
|
|
|
|
class TestChatGPTAdapter:
    """Tests for ``ChatGPTAdapter`` with the HTTP client's ``post`` mocked out."""

    @pytest.fixture
    def chatgpt_adapter(self):
        """Return an adapter configured with a dummy API key."""
        return ChatGPTAdapter(api_key="test-api-key")

    @staticmethod
    def _ok_response(payload):
        """Build a MagicMock standing in for a 200 response whose ``json()`` returns *payload*."""
        response = MagicMock()
        response.status_code = 200
        response.json.return_value = payload
        return response

    def test_chatgpt_init(self, chatgpt_adapter):
        assert chatgpt_adapter.api_key == "test-api-key"
        assert chatgpt_adapter.get_engine_type() == EngineType.CHATGPT

    @pytest.mark.asyncio
    async def test_chatgpt_query_success(self, chatgpt_adapter):
        mock_response = self._ok_response(
            {
                "choices": [
                    {
                        "message": {
                            "content": "BrandX is a leading insurance company with great service."
                        }
                    }
                ],
                "model": "gpt-4o",
                "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
            }
        )

        with patch.object(chatgpt_adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response
            result = await chatgpt_adapter.query(
                query="best insurance companies",
                brand_name="BrandX",
                competitor_names=["CompetitorY"],
            )

        assert isinstance(result, AIQueryResult)
        assert result.engine_type == EngineType.CHATGPT
        assert result.query == "best insurance companies"
        assert "BrandX" in result.raw_response
        assert result.has_brand_citation is True
        assert result.response_time_ms >= 0

    @pytest.mark.asyncio
    async def test_chatgpt_query_with_rate_limiter(self):
        mock_limiter = AsyncMock()
        adapter = ChatGPTAdapter(api_key="test-key", rate_limiter=mock_limiter)

        mock_response = self._ok_response(
            {
                "choices": [{"message": {"content": "Some response"}}],
                "model": "gpt-4o",
            }
        )

        with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response
            await adapter.query(query="test", brand_name="BrandX")

        mock_limiter.acquire.assert_awaited()

    @pytest.mark.asyncio
    async def test_chatgpt_query_api_timeout(self, chatgpt_adapter):
        import httpx

        with patch.object(chatgpt_adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.side_effect = httpx.TimeoutException("Request timed out")

            # This test does not pin whether the adapter propagates or wraps the
            # timeout, so any exception type is accepted here...
            with pytest.raises(Exception):
                await chatgpt_adapter.query(
                    query="test query",
                    brand_name="BrandX",
                )

        # ...but the request must actually have been attempted; otherwise an
        # unrelated error raised before the HTTP call would also pass.
        mock_post.assert_awaited()

    @pytest.mark.asyncio
    async def test_chatgpt_query_invalid_response(self, chatgpt_adapter):
        mock_response = MagicMock()
        mock_response.status_code = 401
        mock_response.text = "Unauthorized"
        # NOTE(review): MagicMock's raise_for_status() is a no-op. If the adapter
        # relies on it instead of checking status_code, this test may pass for an
        # unrelated reason — confirm against the adapter implementation.

        with patch.object(chatgpt_adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response

            with pytest.raises(Exception):
                await chatgpt_adapter.query(
                    query="test query",
                    brand_name="BrandX",
                )

        # The failure must occur after the request was made, not before it.
        mock_post.assert_awaited()

    @pytest.mark.asyncio
    async def test_chatgpt_brand_citation_detection(self, chatgpt_adapter):
        mock_response = self._ok_response(
            {
                "choices": [
                    {"message": {"content": "BrandX offers excellent insurance coverage."}}
                ],
                "model": "gpt-4o",
            }
        )

        with patch.object(chatgpt_adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response
            result = await chatgpt_adapter.query(
                query="insurance",
                brand_name="BrandX",
            )

        assert result.has_brand_citation is True
        assert result.brand_context is not None
        assert "BrandX" in result.brand_context

    @pytest.mark.asyncio
    async def test_chatgpt_competitor_citation_detection(self, chatgpt_adapter):
        mock_response = self._ok_response(
            {
                "choices": [
                    {
                        "message": {
                            "content": "CompetitorY and CompetitorZ are popular insurance providers."
                        }
                    }
                ],
                "model": "gpt-4o",
            }
        )

        with patch.object(chatgpt_adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response
            result = await chatgpt_adapter.query(
                query="insurance",
                brand_name="BrandX",
                competitor_names=["CompetitorY", "CompetitorZ"],
            )

        assert result.has_brand_citation is False
        assert result.has_competitor_citation is True
        # Both competitors appear in the text, so both yield a context.
        assert len(result.competitor_contexts) == 2
|
|
|
|
|
|
class TestPerplexityAdapter:
    """Tests for ``PerplexityAdapter`` with the HTTP client's ``post`` mocked out."""

    @pytest.fixture
    def perplexity_adapter(self):
        """Return an adapter configured with a dummy API key."""
        return PerplexityAdapter(api_key="test-api-key")

    @staticmethod
    def _ok_response(payload):
        """Build a MagicMock standing in for a 200 response whose ``json()`` returns *payload*."""
        response = MagicMock()
        response.status_code = 200
        response.json.return_value = payload
        return response

    def test_perplexity_init(self, perplexity_adapter):
        assert perplexity_adapter.api_key == "test-api-key"
        assert perplexity_adapter.get_engine_type() == EngineType.PERPLEXITY

    @pytest.mark.asyncio
    async def test_perplexity_query_success(self, perplexity_adapter):
        mock_response = self._ok_response(
            {
                "choices": [
                    {
                        "message": {
                            "content": "BrandX is a well-known insurance brand [1]."
                        }
                    }
                ],
                "citations": [
                    {"url": "https://brandx.com", "title": "BrandX Official Site"}
                ],
                "model": "pplx-70b-online",
            }
        )

        with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response
            result = await perplexity_adapter.query(
                query="best insurance companies",
                brand_name="BrandX",
                competitor_names=["CompetitorY"],
            )

        assert isinstance(result, AIQueryResult)
        assert result.engine_type == EngineType.PERPLEXITY
        assert result.query == "best insurance companies"
        assert "BrandX" in result.raw_response
        assert result.has_brand_citation is True
        assert len(result.citations) >= 1

    @pytest.mark.asyncio
    async def test_perplexity_query_with_citations(self, perplexity_adapter):
        mock_response = self._ok_response(
            {
                "choices": [
                    {"message": {"content": "BrandX is recommended [1]. CompetitorY is also good [2]."}}
                ],
                "citations": [
                    {"url": "https://brandx.com", "title": "BrandX"},
                    {"url": "https://competitory.com", "title": "CompetitorY"},
                ],
                "model": "pplx-70b-online",
            }
        )

        with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response
            result = await perplexity_adapter.query(
                query="insurance",
                brand_name="BrandX",
                competitor_names=["CompetitorY"],
            )

        # Citations are surfaced in the order the API returned them.
        assert len(result.citations) == 2
        assert result.citations[0].source_url == "https://brandx.com"
        assert result.citations[1].source_url == "https://competitory.com"

    @pytest.mark.asyncio
    async def test_perplexity_query_with_rate_limiter(self):
        mock_limiter = AsyncMock()
        adapter = PerplexityAdapter(api_key="test-key", rate_limiter=mock_limiter)

        mock_response = self._ok_response(
            {
                "choices": [{"message": {"content": "Some response"}}],
                "model": "pplx-70b-online",
                "citations": [],
            }
        )

        with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response
            await adapter.query(query="test", brand_name="BrandX")

        mock_limiter.acquire.assert_awaited()

    @pytest.mark.asyncio
    async def test_perplexity_query_api_timeout(self, perplexity_adapter):
        import httpx

        with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.side_effect = httpx.TimeoutException("Request timed out")

            # This test does not pin whether the adapter propagates or wraps the
            # timeout, so any exception type is accepted here...
            with pytest.raises(Exception):
                await perplexity_adapter.query(
                    query="test query",
                    brand_name="BrandX",
                )

        # ...but the request must actually have been attempted; otherwise an
        # unrelated error raised before the HTTP call would also pass.
        mock_post.assert_awaited()

    @pytest.mark.asyncio
    async def test_perplexity_query_invalid_response(self, perplexity_adapter):
        mock_response = MagicMock()
        mock_response.status_code = 401
        mock_response.text = "Unauthorized"
        # NOTE(review): MagicMock's raise_for_status() is a no-op. If the adapter
        # relies on it instead of checking status_code, this test may pass for an
        # unrelated reason — confirm against the adapter implementation.

        with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response

            with pytest.raises(Exception):
                await perplexity_adapter.query(
                    query="test query",
                    brand_name="BrandX",
                )

        # The failure must occur after the request was made, not before it.
        mock_post.assert_awaited()

    @pytest.mark.asyncio
    async def test_perplexity_brand_citation_detection(self, perplexity_adapter):
        mock_response = self._ok_response(
            {
                "choices": [
                    {"message": {"content": "BrandX offers excellent insurance coverage."}}
                ],
                "citations": [],
                "model": "pplx-70b-online",
            }
        )

        with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response
            result = await perplexity_adapter.query(
                query="insurance",
                brand_name="BrandX",
            )

        assert result.has_brand_citation is True
        assert result.brand_context is not None

    @pytest.mark.asyncio
    async def test_perplexity_competitor_citation_detection(self, perplexity_adapter):
        mock_response = self._ok_response(
            {
                "choices": [
                    {"message": {"content": "CompetitorY is a popular insurance provider."}}
                ],
                "citations": [],
                "model": "pplx-70b-online",
            }
        )

        with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response
            result = await perplexity_adapter.query(
                query="insurance",
                brand_name="BrandX",
                competitor_names=["CompetitorY"],
            )

        assert result.has_brand_citation is False
        assert result.has_competitor_citation is True
        assert len(result.competitor_contexts) == 1

    @pytest.mark.asyncio
    async def test_perplexity_no_citations_field(self, perplexity_adapter):
        # The payload omits the "citations" key entirely; the result must
        # still carry an empty citation list.
        mock_response = self._ok_response(
            {
                "choices": [{"message": {"content": "Some response without citations field"}}],
                "model": "pplx-70b-online",
            }
        )

        with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = mock_response
            result = await perplexity_adapter.query(
                query="test",
                brand_name="BrandX",
            )

        assert result.citations == []
|