# geo/backend/tests/test_services/test_proxy_and_deepseek.py
#
# 289 lines
# 11 KiB
# Python

import os
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType
from app.services.ai_engine.chatgpt import ChatGPTAdapter
from app.services.ai_engine.deepseek import DeepSeekAdapter
from app.services.ai_engine.perplexity import PerplexityAdapter
class _ConcreteAdapter(AIEngineAdapter):
    """Minimal AIEngineAdapter subclass for exercising base-class behavior.

    It lets the proxy/client tests construct the base adapter directly.
    ``query`` is a stub that does nothing and returns ``None``.
    """

    async def query(self, query, brand_name, competitor_names=None):
        # Stub: the base-class tests never issue a query.
        return None

    def get_engine_type(self):
        return EngineType.CHATGPT

    def _get_env_key(self) -> str | None:
        # No API-key environment variable; callers pass api_key explicitly.
        return None
class TestProxySupportBase:
    """Proxy resolution and HTTP-client creation in the AIEngineAdapter base class."""

    def test_base_init_with_proxy(self):
        adapter = _ConcreteAdapter(api_key="test-key", proxy="http://proxy:8080")
        assert adapter.proxy == "http://proxy:8080"

    def test_base_init_proxy_default_none(self):
        # Strip proxies inherited from the developer/CI shell; otherwise this
        # test fails on any machine that runs behind a proxy.
        with patch.dict(os.environ):
            for key in ("HTTPS_PROXY", "https_proxy"):
                os.environ.pop(key, None)
            adapter = _ConcreteAdapter(api_key="test-key")
        assert adapter.proxy is None

    def test_base_init_proxy_from_https_proxy_env(self):
        with patch.dict(os.environ):
            # Clear the lowercase variant first so only HTTPS_PROXY is visible.
            os.environ.pop("https_proxy", None)
            os.environ["HTTPS_PROXY"] = "http://env-proxy:3128"
            adapter = _ConcreteAdapter(api_key="test-key")
        assert adapter.proxy == "http://env-proxy:3128"

    def test_base_init_proxy_from_https_proxy_lowercase_env(self):
        with patch.dict(os.environ):
            # Remove HTTPS_PROXY *before* setting the lowercase name: where the
            # environment is case-insensitive (Windows) both names refer to the
            # same entry, and deleting afterwards would erase the value under test.
            os.environ.pop("HTTPS_PROXY", None)
            os.environ["https_proxy"] = "http://lower-proxy:3128"
            adapter = _ConcreteAdapter(api_key="test-key")
        assert adapter.proxy == "http://lower-proxy:3128"

    def test_base_init_explicit_proxy_overrides_env(self):
        with patch.dict(os.environ, {"HTTPS_PROXY": "http://env-proxy:3128"}):
            adapter = _ConcreteAdapter(
                api_key="test-key", proxy="http://explicit-proxy:8080"
            )
        assert adapter.proxy == "http://explicit-proxy:8080"

    @pytest.mark.asyncio
    async def test_get_client_with_proxy(self):
        adapter = _ConcreteAdapter(api_key="test-key", proxy="http://proxy:8080")
        try:
            client = await adapter._get_client()
            assert isinstance(client, httpx.AsyncClient)
        finally:
            # Close even when the assertion fails so the client is not leaked.
            await adapter.close()

    @pytest.mark.asyncio
    async def test_get_client_without_proxy(self):
        adapter = _ConcreteAdapter(api_key="test-key")
        try:
            client = await adapter._get_client()
            assert isinstance(client, httpx.AsyncClient)
        finally:
            await adapter.close()
class TestChatGPTProxy:
    """Proxy resolution for ChatGPTAdapter (explicit arg, OPENAI_PROXY, HTTPS_PROXY fallback)."""

    def test_chatgpt_proxy_parameter(self):
        adapter = ChatGPTAdapter(api_key="test-key", proxy="http://proxy:8080")
        assert adapter.proxy == "http://proxy:8080"

    def test_chatgpt_proxy_from_openai_proxy_env(self):
        with patch.dict(os.environ, {"OPENAI_PROXY": "http://openai-proxy:8080"}):
            adapter = ChatGPTAdapter(api_key="test-key")
        assert adapter.proxy == "http://openai-proxy:8080"

    def test_chatgpt_explicit_proxy_overrides_env(self):
        with patch.dict(os.environ, {"OPENAI_PROXY": "http://env-proxy:8080"}):
            adapter = ChatGPTAdapter(
                api_key="test-key", proxy="http://explicit:9090"
            )
        assert adapter.proxy == "http://explicit:9090"

    def test_chatgpt_fallback_to_https_proxy_env(self):
        with patch.dict(os.environ):
            # Clear competing variables first and only then set the one under
            # test; the reverse order breaks on case-insensitive environments
            # (Windows), where https_proxy and HTTPS_PROXY are the same entry.
            for key in ("OPENAI_PROXY", "https_proxy"):
                os.environ.pop(key, None)
            os.environ["HTTPS_PROXY"] = "http://fallback:3128"
            adapter = ChatGPTAdapter(api_key="test-key")
        assert adapter.proxy == "http://fallback:3128"

    def test_chatgpt_no_proxy(self):
        with patch.dict(os.environ):
            for key in ("OPENAI_PROXY", "HTTPS_PROXY", "https_proxy"):
                os.environ.pop(key, None)
            adapter = ChatGPTAdapter(api_key="test-key")
        assert adapter.proxy is None
class TestPerplexityProxy:
    """Proxy resolution for PerplexityAdapter (explicit arg, PERPLEXITY_PROXY, HTTPS_PROXY fallback)."""

    def test_perplexity_proxy_parameter(self):
        adapter = PerplexityAdapter(api_key="test-key", proxy="http://proxy:8080")
        assert adapter.proxy == "http://proxy:8080"

    def test_perplexity_proxy_from_perplexity_proxy_env(self):
        with patch.dict(os.environ, {"PERPLEXITY_PROXY": "http://pplx-proxy:8080"}):
            adapter = PerplexityAdapter(api_key="test-key")
        assert adapter.proxy == "http://pplx-proxy:8080"

    def test_perplexity_explicit_proxy_overrides_env(self):
        # Mirrors the base-class and ChatGPT suites: an explicit argument wins.
        with patch.dict(os.environ, {"PERPLEXITY_PROXY": "http://env-proxy:8080"}):
            adapter = PerplexityAdapter(
                api_key="test-key", proxy="http://explicit:9090"
            )
        assert adapter.proxy == "http://explicit:9090"

    def test_perplexity_fallback_to_https_proxy_env(self):
        with patch.dict(os.environ):
            # Clear competing variables before setting HTTPS_PROXY so the test
            # also holds where environment names are case-insensitive.
            for key in ("PERPLEXITY_PROXY", "https_proxy"):
                os.environ.pop(key, None)
            os.environ["HTTPS_PROXY"] = "http://fallback:3128"
            adapter = PerplexityAdapter(api_key="test-key")
        assert adapter.proxy == "http://fallback:3128"

    def test_perplexity_no_proxy(self):
        with patch.dict(os.environ):
            for key in ("PERPLEXITY_PROXY", "HTTPS_PROXY", "https_proxy"):
                os.environ.pop(key, None)
            adapter = PerplexityAdapter(api_key="test-key")
        assert adapter.proxy is None
class TestDeepSeekAdapter:
    """Construction, configuration, proxy and query handling of DeepSeekAdapter."""

    # Environment variables that override the adapter's model/base-URL defaults.
    # The "default" tests remove them so a developer's shell cannot leak in.
    _CONFIG_ENV_VARS = ("DEEPSEEK_MODEL", "DEEPSEEK_BASE_URL")

    @classmethod
    def _adapter_with_default_config(cls):
        """Build an adapter with DEEPSEEK_MODEL / DEEPSEEK_BASE_URL unset."""
        with patch.dict(os.environ):
            for key in cls._CONFIG_ENV_VARS:
                os.environ.pop(key, None)
            return DeepSeekAdapter(api_key="test-key")

    @staticmethod
    def _make_response(status_code, payload=None, text=None):
        """Return a MagicMock standing in for an ``httpx.Response``.

        For status codes >= 400, ``raise_for_status`` raises
        ``httpx.HTTPStatusError`` as the real object would. A bare MagicMock
        would silently succeed there and hide the error path.
        """
        response = MagicMock()
        response.status_code = status_code
        if payload is not None:
            response.json.return_value = payload
        if text is not None:
            response.text = text
        if status_code >= 400:
            request = httpx.Request(
                "POST", "https://api.deepseek.com/v1/chat/completions"
            )
            response.raise_for_status.side_effect = httpx.HTTPStatusError(
                f"HTTP {status_code}", request=request, response=response
            )
        return response

    def test_deepseek_init_default(self):
        with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "ds-test-key"}):
            adapter = DeepSeekAdapter()
        assert adapter.api_key == "ds-test-key"
        assert adapter.get_engine_type() == EngineType.DEEPSEEK

    def test_deepseek_init_with_params(self):
        adapter = DeepSeekAdapter(
            api_key="custom-key",
            model="deepseek-reasoner",
            base_url="https://custom.api.com/v1",
        )
        assert adapter.api_key == "custom-key"
        assert adapter._model == "deepseek-reasoner"
        assert adapter._base_url == "https://custom.api.com/v1"

    def test_deepseek_init_from_env(self):
        with patch.dict(
            os.environ,
            {
                "DEEPSEEK_API_KEY": "env-key",
                "DEEPSEEK_MODEL": "deepseek-coder",
                "DEEPSEEK_BASE_URL": "https://env.api.com/v1",
            },
        ):
            adapter = DeepSeekAdapter()
        assert adapter.api_key == "env-key"
        assert adapter._model == "deepseek-coder"
        assert adapter._base_url == "https://env.api.com/v1"

    def test_deepseek_default_model(self):
        adapter = self._adapter_with_default_config()
        assert adapter._model == "deepseek-chat"

    def test_deepseek_default_base_url(self):
        adapter = self._adapter_with_default_config()
        assert adapter._base_url == "https://api.deepseek.com/v1"

    def test_deepseek_endpoint(self):
        adapter = self._adapter_with_default_config()
        assert adapter._endpoint == "https://api.deepseek.com/v1/chat/completions"

    def test_deepseek_engine_type(self):
        adapter = DeepSeekAdapter(api_key="test-key")
        assert adapter.get_engine_type() == EngineType.DEEPSEEK

    def test_deepseek_proxy_parameter(self):
        adapter = DeepSeekAdapter(api_key="test-key", proxy="http://proxy:8080")
        assert adapter.proxy == "http://proxy:8080"

    def test_deepseek_proxy_from_deepseek_proxy_env(self):
        with patch.dict(os.environ, {"DEEPSEEK_PROXY": "http://ds-proxy:8080"}):
            adapter = DeepSeekAdapter(api_key="test-key")
        assert adapter.proxy == "http://ds-proxy:8080"

    @pytest.mark.asyncio
    async def test_deepseek_query_success(self):
        response = self._make_response(
            200,
            {
                "choices": [
                    {"message": {"content": "BrandX is a leading AI company with great models."}}
                ],
                "model": "deepseek-chat",
                "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
            },
        )
        # async-with closes the adapter's HTTP client when the test finishes.
        async with DeepSeekAdapter(api_key="test-key") as adapter:
            with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
                mock_post.return_value = response
                result = await adapter.query(
                    query="best AI companies",
                    brand_name="BrandX",
                    competitor_names=["CompetitorY"],
                )
        mock_post.assert_awaited_once()
        assert isinstance(result, AIQueryResult)
        assert result.engine_type == EngineType.DEEPSEEK
        assert result.query == "best AI companies"
        assert "BrandX" in result.raw_response
        assert result.has_brand_citation is True
        assert result.has_competitor_citation is False
        assert result.response_time_ms >= 0

    @pytest.mark.asyncio
    async def test_deepseek_query_with_competitor(self):
        response = self._make_response(
            200,
            {
                "choices": [
                    {"message": {"content": "BrandX and CompetitorY are both strong AI companies."}}
                ],
                "model": "deepseek-chat",
            },
        )
        async with DeepSeekAdapter(api_key="test-key") as adapter:
            with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
                mock_post.return_value = response
                result = await adapter.query(
                    query="AI comparison",
                    brand_name="BrandX",
                    competitor_names=["CompetitorY"],
                )
        assert result.has_brand_citation is True
        assert result.has_competitor_citation is True
        assert len(result.competitor_contexts) == 1

    @pytest.mark.asyncio
    async def test_deepseek_query_with_rate_limiter(self):
        mock_limiter = AsyncMock()
        response = self._make_response(
            200,
            {
                "choices": [{"message": {"content": "Some response"}}],
                "model": "deepseek-chat",
            },
        )
        async with DeepSeekAdapter(api_key="test-key", rate_limiter=mock_limiter) as adapter:
            with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
                mock_post.return_value = response
                await adapter.query(query="test", brand_name="BrandX")
        mock_limiter.acquire.assert_awaited()

    @pytest.mark.asyncio
    async def test_deepseek_query_api_error(self):
        response = self._make_response(401, text="Unauthorized")
        async with DeepSeekAdapter(api_key="test-key") as adapter:
            with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
                mock_post.return_value = response
                # NOTE(review): the adapter's concrete exception type is not
                # visible from this file; narrow this once it is pinned down.
                with pytest.raises(Exception):
                    await adapter.query(query="test", brand_name="BrandX")
        # Ensures the broad pytest.raises did not pass on an error raised
        # before any request was sent.
        mock_post.assert_awaited()

    @pytest.mark.asyncio
    async def test_deepseek_query_transport_error(self):
        async with DeepSeekAdapter(api_key="test-key") as adapter:
            with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
                mock_post.side_effect = httpx.TransportError("Connection failed")
                with pytest.raises(Exception):
                    await adapter.query(query="test", brand_name="BrandX")
        mock_post.assert_awaited()

    @pytest.mark.asyncio
    async def test_deepseek_query_timeout(self):
        async with DeepSeekAdapter(api_key="test-key") as adapter:
            with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
                mock_post.side_effect = httpx.TimeoutException("Request timed out")
                with pytest.raises(Exception):
                    await adapter.query(query="test", brand_name="BrandX")
        mock_post.assert_awaited()

    @pytest.mark.asyncio
    async def test_deepseek_close(self):
        adapter = DeepSeekAdapter(api_key="test-key")
        # Smoke test: closing a freshly constructed adapter must not raise.
        await adapter.close()

    @pytest.mark.asyncio
    async def test_deepseek_context_manager(self):
        async with DeepSeekAdapter(api_key="test-key") as adapter:
            assert adapter.api_key == "test-key"