289 lines
11 KiB
Python
289 lines
11 KiB
Python
import os
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType
|
|
from app.services.ai_engine.chatgpt import ChatGPTAdapter
|
|
from app.services.ai_engine.deepseek import DeepSeekAdapter
|
|
from app.services.ai_engine.perplexity import PerplexityAdapter
|
|
|
|
|
|
class _ConcreteAdapter(AIEngineAdapter):
    """Smallest possible concrete ``AIEngineAdapter`` for exercising the base class.

    Every engine-specific hook is stubbed so the base-class constructor and
    HTTP-client helpers can be tested without a real engine behind them.
    """

    def _get_env_key(self) -> str | None:
        # NOTE(review): presumably names an engine-specific environment
        # variable; returning None opts this stub out of it — confirm against
        # AIEngineAdapter.
        return None

    def get_engine_type(self):
        # Any valid member works here; CHATGPT is an arbitrary choice.
        return EngineType.CHATGPT

    async def query(self, query, brand_name, competitor_names=None):
        # Not called by the tests in this module; exists only so the
        # subclass implements the full adapter interface.
        return None
|
|
|
|
|
|
class TestProxySupportBase:
    """Proxy configuration performed by the ``AIEngineAdapter`` base class.

    Environment-sensitive tests run inside ``patch.dict(os.environ)`` so every
    change is rolled back afterwards. They first remove all proxy variables so
    the developer's shell cannot influence the result.
    """

    # Proxy variables cleared before env-sensitive assertions. These tests show
    # the base class reads HTTPS_PROXY / https_proxy; the remaining names are
    # cleared defensively in case it consults them too (harmless otherwise,
    # since patch.dict restores the environment on exit).
    _PROXY_ENV_VARS = (
        "HTTPS_PROXY",
        "https_proxy",
        "HTTP_PROXY",
        "http_proxy",
        "ALL_PROXY",
        "all_proxy",
    )

    @classmethod
    def _clear_proxy_env(cls):
        """Remove every proxy variable from ``os.environ``.

        Call this only inside a ``patch.dict(os.environ)`` context. Clearing
        happens *before* any test value is set, because on Windows
        ``os.environ`` keys are case-insensitive: deleting ``HTTPS_PROXY``
        after setting ``https_proxy`` would delete the value just set.
        """
        for name in cls._PROXY_ENV_VARS:
            os.environ.pop(name, None)

    def test_base_init_with_proxy(self):
        adapter = _ConcreteAdapter(api_key="test-key", proxy="http://proxy:8080")
        assert adapter.proxy == "http://proxy:8080"

    def test_base_init_proxy_default_none(self):
        # Without isolation this failed whenever the shell exported a proxy.
        with patch.dict(os.environ):
            self._clear_proxy_env()
            adapter = _ConcreteAdapter(api_key="test-key")
            assert adapter.proxy is None

    def test_base_init_proxy_from_https_proxy_env(self):
        with patch.dict(os.environ):
            # Clear first so a stray lowercase https_proxy cannot compete.
            self._clear_proxy_env()
            os.environ["HTTPS_PROXY"] = "http://env-proxy:3128"
            adapter = _ConcreteAdapter(api_key="test-key")
            assert adapter.proxy == "http://env-proxy:3128"

    def test_base_init_proxy_from_https_proxy_lowercase_env(self):
        with patch.dict(os.environ):
            # Clear *then* set: the old delete-after-set order wiped the value
            # on case-insensitive (Windows) environments.
            self._clear_proxy_env()
            os.environ["https_proxy"] = "http://lower-proxy:3128"
            adapter = _ConcreteAdapter(api_key="test-key")
            assert adapter.proxy == "http://lower-proxy:3128"

    def test_base_init_explicit_proxy_overrides_env(self):
        with patch.dict(os.environ, {"HTTPS_PROXY": "http://env-proxy:3128"}):
            adapter = _ConcreteAdapter(
                api_key="test-key", proxy="http://explicit-proxy:8080"
            )
            assert adapter.proxy == "http://explicit-proxy:8080"

    @pytest.mark.asyncio
    async def test_get_client_with_proxy(self):
        adapter = _ConcreteAdapter(api_key="test-key", proxy="http://proxy:8080")
        try:
            client = await adapter._get_client()
            assert isinstance(client, httpx.AsyncClient)
        finally:
            # Close even when the assertion fails so the client is not leaked.
            await adapter.close()

    @pytest.mark.asyncio
    async def test_get_client_without_proxy(self):
        adapter = _ConcreteAdapter(api_key="test-key")
        try:
            client = await adapter._get_client()
            assert isinstance(client, httpx.AsyncClient)
        finally:
            await adapter.close()
|
|
|
|
|
|
class TestChatGPTProxy:
    """Proxy resolution for ``ChatGPTAdapter``.

    Pins that an explicit ``proxy`` argument wins over ``OPENAI_PROXY``, that
    ``OPENAI_PROXY`` is honoured, that ``HTTPS_PROXY`` is used when
    ``OPENAI_PROXY`` is absent, and that no proxy is set when all are absent.
    """

    def test_chatgpt_proxy_parameter(self):
        expected = "http://proxy:8080"
        adapter = ChatGPTAdapter(api_key="test-key", proxy=expected)
        assert adapter.proxy == expected

    def test_chatgpt_proxy_from_openai_proxy_env(self):
        expected = "http://openai-proxy:8080"
        with patch.dict(os.environ, OPENAI_PROXY=expected):
            adapter = ChatGPTAdapter(api_key="test-key")
            assert adapter.proxy == expected

    def test_chatgpt_explicit_proxy_overrides_env(self):
        explicit = "http://explicit:9090"
        with patch.dict(os.environ, OPENAI_PROXY="http://env-proxy:8080"):
            adapter = ChatGPTAdapter(api_key="test-key", proxy=explicit)
            assert adapter.proxy == explicit

    def test_chatgpt_fallback_to_https_proxy_env(self):
        expected = "http://fallback:3128"
        with patch.dict(os.environ, HTTPS_PROXY=expected):
            # Remove the engine-specific variable so only the generic one is left.
            os.environ.pop("OPENAI_PROXY", None)
            adapter = ChatGPTAdapter(api_key="test-key")
            assert adapter.proxy == expected

    def test_chatgpt_no_proxy(self):
        # patch.dict with no values still snapshots and restores os.environ.
        with patch.dict(os.environ):
            os.environ.pop("OPENAI_PROXY", None)
            os.environ.pop("HTTPS_PROXY", None)
            os.environ.pop("https_proxy", None)
            adapter = ChatGPTAdapter(api_key="test-key")
            assert adapter.proxy is None
|
|
|
|
|
|
class TestPerplexityProxy:
    """Proxy resolution for ``PerplexityAdapter``.

    Pins that an explicit ``proxy`` argument is stored as given, that
    ``PERPLEXITY_PROXY`` is honoured, and that ``HTTPS_PROXY`` is used when
    ``PERPLEXITY_PROXY`` is absent.
    """

    def test_perplexity_proxy_parameter(self):
        expected = "http://proxy:8080"
        adapter = PerplexityAdapter(api_key="test-key", proxy=expected)
        assert adapter.proxy == expected

    def test_perplexity_proxy_from_perplexity_proxy_env(self):
        expected = "http://pplx-proxy:8080"
        with patch.dict(os.environ, PERPLEXITY_PROXY=expected):
            adapter = PerplexityAdapter(api_key="test-key")
            assert adapter.proxy == expected

    def test_perplexity_fallback_to_https_proxy_env(self):
        expected = "http://fallback:3128"
        with patch.dict(os.environ, HTTPS_PROXY=expected):
            # Drop the engine-specific variable so only the generic one remains.
            os.environ.pop("PERPLEXITY_PROXY", None)
            adapter = PerplexityAdapter(api_key="test-key")
            assert adapter.proxy == expected
|
|
|
|
|
|
class TestDeepSeekAdapter:
    """Construction, configuration and query behaviour of ``DeepSeekAdapter``.

    The async tests open the adapter with ``async with`` so its HTTP client is
    closed even when an assertion fails. ``adapter._client.post`` is patched so
    no network traffic occurs.
    """

    @staticmethod
    def _ok_response(payload):
        """Return a stand-in for an HTTP 200 response whose ``.json()`` is *payload*."""
        response = MagicMock()
        response.status_code = 200
        response.json.return_value = payload
        return response

    def test_deepseek_init_default(self):
        with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "ds-test-key"}):
            adapter = DeepSeekAdapter()
            assert adapter.api_key == "ds-test-key"
            assert adapter.get_engine_type() == EngineType.DEEPSEEK

    def test_deepseek_init_with_params(self):
        adapter = DeepSeekAdapter(
            api_key="custom-key",
            model="deepseek-reasoner",
            base_url="https://custom.api.com/v1",
        )
        assert adapter.api_key == "custom-key"
        assert adapter._model == "deepseek-reasoner"
        assert adapter._base_url == "https://custom.api.com/v1"

    def test_deepseek_init_from_env(self):
        with patch.dict(
            os.environ,
            {
                "DEEPSEEK_API_KEY": "env-key",
                "DEEPSEEK_MODEL": "deepseek-coder",
                "DEEPSEEK_BASE_URL": "https://env.api.com/v1",
            },
        ):
            adapter = DeepSeekAdapter()
            assert adapter.api_key == "env-key"
            assert adapter._model == "deepseek-coder"
            assert adapter._base_url == "https://env.api.com/v1"

    def test_deepseek_default_model(self):
        # DEEPSEEK_MODEL overrides the default (see test_deepseek_init_from_env),
        # so remove it to keep this test independent of the developer's shell.
        with patch.dict(os.environ):
            os.environ.pop("DEEPSEEK_MODEL", None)
            adapter = DeepSeekAdapter(api_key="test-key")
            assert adapter._model == "deepseek-chat"

    def test_deepseek_default_base_url(self):
        # DEEPSEEK_BASE_URL overrides the default; isolate from the shell.
        with patch.dict(os.environ):
            os.environ.pop("DEEPSEEK_BASE_URL", None)
            adapter = DeepSeekAdapter(api_key="test-key")
            assert adapter._base_url == "https://api.deepseek.com/v1"

    def test_deepseek_endpoint(self):
        # The endpoint is derived from the base URL, which the env can override.
        with patch.dict(os.environ):
            os.environ.pop("DEEPSEEK_BASE_URL", None)
            adapter = DeepSeekAdapter(api_key="test-key")
            assert adapter._endpoint == "https://api.deepseek.com/v1/chat/completions"

    def test_deepseek_engine_type(self):
        adapter = DeepSeekAdapter(api_key="test-key")
        assert adapter.get_engine_type() == EngineType.DEEPSEEK

    def test_deepseek_proxy_parameter(self):
        adapter = DeepSeekAdapter(api_key="test-key", proxy="http://proxy:8080")
        assert adapter.proxy == "http://proxy:8080"

    def test_deepseek_proxy_from_deepseek_proxy_env(self):
        with patch.dict(os.environ, {"DEEPSEEK_PROXY": "http://ds-proxy:8080"}):
            adapter = DeepSeekAdapter(api_key="test-key")
            assert adapter.proxy == "http://ds-proxy:8080"

    @pytest.mark.asyncio
    async def test_deepseek_query_success(self):
        mock_response = self._ok_response(
            {
                "choices": [
                    {"message": {"content": "BrandX is a leading AI company with great models."}}
                ],
                "model": "deepseek-chat",
                "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
            }
        )

        async with DeepSeekAdapter(api_key="test-key") as adapter:
            with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
                mock_post.return_value = mock_response
                result = await adapter.query(
                    query="best AI companies",
                    brand_name="BrandX",
                    competitor_names=["CompetitorY"],
                )

        assert isinstance(result, AIQueryResult)
        assert result.engine_type == EngineType.DEEPSEEK
        assert result.query == "best AI companies"
        assert "BrandX" in result.raw_response
        assert result.has_brand_citation is True
        assert result.has_competitor_citation is False
        assert result.response_time_ms >= 0

    @pytest.mark.asyncio
    async def test_deepseek_query_with_competitor(self):
        mock_response = self._ok_response(
            {
                "choices": [
                    {"message": {"content": "BrandX and CompetitorY are both strong AI companies."}}
                ],
                "model": "deepseek-chat",
            }
        )

        async with DeepSeekAdapter(api_key="test-key") as adapter:
            with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
                mock_post.return_value = mock_response
                result = await adapter.query(
                    query="AI comparison",
                    brand_name="BrandX",
                    competitor_names=["CompetitorY"],
                )

        assert result.has_brand_citation is True
        assert result.has_competitor_citation is True
        assert len(result.competitor_contexts) == 1

    @pytest.mark.asyncio
    async def test_deepseek_query_with_rate_limiter(self):
        mock_limiter = AsyncMock()
        mock_response = self._ok_response(
            {
                "choices": [{"message": {"content": "Some response"}}],
                "model": "deepseek-chat",
            }
        )

        async with DeepSeekAdapter(api_key="test-key", rate_limiter=mock_limiter) as adapter:
            with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
                mock_post.return_value = mock_response
                await adapter.query(query="test", brand_name="BrandX")

        mock_limiter.acquire.assert_awaited()

    @pytest.mark.asyncio
    async def test_deepseek_query_api_error(self):
        async with DeepSeekAdapter(api_key="test-key") as adapter:
            # A real httpx.Response (with its request attached) so that
            # raise_for_status() genuinely raises. On a bare MagicMock it is a
            # silent no-op, and the test could pass for an unrelated reason.
            error_response = httpx.Response(
                401,
                text="Unauthorized",
                request=httpx.Request("POST", adapter._endpoint),
            )

            with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
                mock_post.return_value = error_response

                # NOTE(review): the adapter's concrete error type is not visible
                # here; narrow this once it is confirmed.
                with pytest.raises(Exception):
                    await adapter.query(query="test", brand_name="BrandX")

    @pytest.mark.asyncio
    async def test_deepseek_query_transport_error(self):
        async with DeepSeekAdapter(api_key="test-key") as adapter:
            with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
                mock_post.side_effect = httpx.TransportError("Connection failed")

                with pytest.raises(Exception):
                    await adapter.query(query="test", brand_name="BrandX")

    @pytest.mark.asyncio
    async def test_deepseek_query_timeout(self):
        async with DeepSeekAdapter(api_key="test-key") as adapter:
            with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
                mock_post.side_effect = httpx.TimeoutException("Request timed out")

                with pytest.raises(Exception):
                    await adapter.query(query="test", brand_name="BrandX")

    @pytest.mark.asyncio
    async def test_deepseek_close(self):
        adapter = DeepSeekAdapter(api_key="test-key")
        await adapter.close()

    @pytest.mark.asyncio
    async def test_deepseek_context_manager(self):
        async with DeepSeekAdapter(api_key="test-key") as adapter:
            assert adapter.api_key == "test-key"
|