geo/backend/tests/test_services/test_llm_provider.py

205 lines
7.4 KiB
Python

"""LLM Provider单元测试"""
import json
import os
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def set_openai_api_key(monkeypatch):
    """Export dummy API keys for every test so provider constructors do not raise.

    Despite the fixture's name, both ``OPENAI_API_KEY`` and ``DEEPSEEK_API_KEY``
    are set. The changes go through ``monkeypatch``, which undoes them after each test.
    """
    fake_credentials = (
        ("OPENAI_API_KEY", "test-key-fixture"),
        ("DEEPSEEK_API_KEY", "test-deepseek-key"),
    )
    for variable, value in fake_credentials:
        monkeypatch.setenv(variable, value)
# ---------------------------------------------------------------------------
# LLMFactory tests
# ---------------------------------------------------------------------------
class TestLLMFactory:
    """Tests for building and listing providers through ``LLMFactory``."""

    def test_factory_create_openai(self):
        """``LLMFactory.create('openai')`` returns an ``OpenAIProvider`` instance."""
        from app.services.llm.factory import LLMFactory
        from app.services.llm.openai_provider import OpenAIProvider

        created = LLMFactory.create("openai")

        assert isinstance(created, OpenAIProvider)
        assert created.provider_name == "openai"

    def test_factory_create_deepseek(self):
        """``LLMFactory.create('deepseek')`` returns a ``DeepSeekProvider`` instance."""
        from app.services.llm.factory import LLMFactory
        from app.services.llm.deepseek_provider import DeepSeekProvider

        created = LLMFactory.create("deepseek")

        assert isinstance(created, DeepSeekProvider)
        assert created.provider_name == "deepseek"

    def test_factory_default_from_env(self, monkeypatch):
        """With no argument, ``DEFAULT_LLM_PROVIDER`` selects which provider is built."""
        from app.services.llm.factory import LLMFactory
        from app.services.llm.deepseek_provider import DeepSeekProvider

        monkeypatch.setenv("DEFAULT_LLM_PROVIDER", "deepseek")
        created = LLMFactory.create()

        assert isinstance(created, DeepSeekProvider)

    def test_factory_unknown_provider_raises(self):
        """Asking for an unregistered provider name raises ``LLMError``."""
        from app.services.llm.factory import LLMFactory
        from app.services.llm.base import LLMError

        bogus_name = "nonexistent_provider_xyz"
        with pytest.raises(LLMError):
            LLMFactory.create(bogus_name)

    def test_factory_list_providers(self):
        """``list_providers`` reports both ``openai`` and ``deepseek``."""
        from app.services.llm.factory import LLMFactory

        registered = LLMFactory.list_providers()

        for expected_name in ("openai", "deepseek"):
            assert expected_name in registered
# ---------------------------------------------------------------------------
# OpenAIProvider 测试
# ---------------------------------------------------------------------------
def _make_chat_response(content="Hello", model="gpt-4o-mini"):
"""构造 OpenAI chat completions 响应体"""
return {
"id": "chatcmpl-test",
"object": "chat.completion",
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": content},
"finish_reason": "stop",
}
],
"model": model,
"usage": {
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15,
},
}
class TestOpenAIProvider:
    """Unit tests for ``OpenAIProvider`` with its httpx client calls mocked out."""

    @pytest.fixture
    def provider(self):
        """Build an ``OpenAIProvider`` with a dummy key and the ``gpt-4o-mini`` model."""
        from app.services.llm.openai_provider import OpenAIProvider
        return OpenAIProvider(api_key="test-key", model="gpt-4o-mini")

    @pytest.mark.asyncio
    async def test_openai_chat_success(self, provider):
        """A mocked 200 httpx response is turned into a populated ``LLMResponse``."""
        from app.services.llm.base import LLMResponse

        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = _make_chat_response("test content")
        with patch.object(provider._client, "post", new=AsyncMock(return_value=mock_response)):
            result = await provider.chat([{"role": "user", "content": "Hello"}])

        assert isinstance(result, LLMResponse)
        assert result.content == "test content"
        assert result.model == "gpt-4o-mini"
        assert "prompt_tokens" in result.usage
        assert "completion_tokens" in result.usage
        assert "total_tokens" in result.usage

    @pytest.mark.asyncio
    async def test_openai_chat_retry_on_429(self, provider):
        """A persistent 429 is retried and then raised as ``LLMError`` once retries run out."""
        from app.services.llm.base import LLMError

        mock_429 = MagicMock()
        mock_429.status_code = 429
        mock_429.text = "Rate limit exceeded"
        mock_429.headers = {}
        post_mock = AsyncMock(return_value=mock_429)
        with patch.object(provider._client, "post", new=post_mock):
            # Stub out sleep so backoff between attempts does not really wait.
            with patch("asyncio.sleep", new=AsyncMock()):
                with pytest.raises(LLMError) as exc_info:
                    await provider.chat([{"role": "user", "content": "Hi"}])

        assert exc_info.value.status_code == 429
        # Check that a retry actually happened: more than one request was sent.
        # The exact attempt count depends on the provider's retry settings,
        # so only "more than one" is asserted.
        assert post_mock.await_count > 1

    @pytest.mark.asyncio
    async def test_openai_chat_error_on_401(self, provider):
        """A 401 raises ``LLMError`` immediately, without retrying."""
        from app.services.llm.base import LLMError

        mock_401 = MagicMock()
        mock_401.status_code = 401
        mock_401.text = "Unauthorized"
        post_mock = AsyncMock(return_value=mock_401)
        with patch.object(provider._client, "post", new=post_mock):
            # Sleep is stubbed here too. If a regression made 401 retryable, the
            # test should fail fast on the call count, not wait through real backoff.
            with patch("asyncio.sleep", new=AsyncMock()):
                with pytest.raises(LLMError) as exc_info:
                    await provider.chat([{"role": "user", "content": "Hi"}])

        assert exc_info.value.status_code == 401
        # 401 is not retryable, so the request is sent exactly once.
        assert post_mock.await_count == 1

    @pytest.mark.asyncio
    async def test_openai_stream_parsing(self, provider):
        """A mocked SSE stream yields one token per content delta; ``[DONE]`` yields none."""
        # Real SSE streams end each event with a blank line. httpx's aiter_lines()
        # yields that line as "", so the fixture includes the separators too.
        sse_lines = [
            'data: {"choices": [{"delta": {"content": "Hello"}}]}',
            "",
            'data: {"choices": [{"delta": {"content": " World"}}]}',
            "",
            'data: [DONE]',
            "",
        ]

        class MockStreamResponse:
            """Minimal async-context-manager stand-in for an httpx streaming response."""

            status_code = 200
            headers = {}

            async def aread(self):
                return b""

            async def aiter_lines(self):
                for line in sse_lines:
                    yield line

            async def __aenter__(self):
                return self

            async def __aexit__(self, *args):
                pass

        with patch.object(provider._client, "stream", return_value=MockStreamResponse()):
            tokens = []
            async for token in provider.chat_stream([{"role": "user", "content": "Hi"}]):
                tokens.append(token)

        assert tokens == ["Hello", " World"]

    def test_llm_response_structure(self):
        """``LLMResponse`` has content/model/usage, and its usage counters default to 0."""
        from app.services.llm.base import LLMResponse

        resp = LLMResponse(content="test", model="gpt-4o-mini")
        assert hasattr(resp, "content")
        assert hasattr(resp, "model")
        assert hasattr(resp, "usage")
        assert resp.usage["prompt_tokens"] == 0
        assert resp.usage["completion_tokens"] == 0
        assert resp.usage["total_tokens"] == 0

    def test_provider_properties(self, provider):
        """The provider reports its name, its configured model and its context window size."""
        assert provider.provider_name == "openai"
        assert provider.model_name == "gpt-4o-mini"
        assert provider.max_context_length == 128_000