205 lines
7.4 KiB
Python
205 lines
7.4 KiB
Python
"""LLM Provider单元测试"""
|
|
import json
|
|
import os
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
import httpx
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@pytest.fixture(autouse=True)
def set_openai_api_key(monkeypatch):
    """Install placeholder API keys in the environment for every test.

    Sets both ``OPENAI_API_KEY`` and ``DEEPSEEK_API_KEY`` through pytest's
    ``monkeypatch`` fixture. The keys only need to be present (so that
    provider construction does not raise for a missing key); their values
    are dummies.
    """
    placeholder_keys = (
        ("OPENAI_API_KEY", "test-key-fixture"),
        ("DEEPSEEK_API_KEY", "test-deepseek-key"),
    )
    for env_name, env_value in placeholder_keys:
        monkeypatch.setenv(env_name, env_value)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# LLMFactory tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestLLMFactory:
    """Tests for creating and listing providers through ``LLMFactory``."""

    def test_factory_create_openai(self):
        """``LLMFactory.create('openai')`` returns an ``OpenAIProvider``."""
        from app.services.llm.factory import LLMFactory
        from app.services.llm.openai_provider import OpenAIProvider

        created = LLMFactory.create("openai")

        assert isinstance(created, OpenAIProvider)
        assert created.provider_name == "openai"

    def test_factory_create_deepseek(self):
        """``LLMFactory.create('deepseek')`` returns a ``DeepSeekProvider``."""
        from app.services.llm.factory import LLMFactory
        from app.services.llm.deepseek_provider import DeepSeekProvider

        created = LLMFactory.create("deepseek")

        assert isinstance(created, DeepSeekProvider)
        assert created.provider_name == "deepseek"

    def test_factory_default_from_env(self, monkeypatch):
        """``DEFAULT_LLM_PROVIDER`` selects the provider when none is named."""
        from app.services.llm.factory import LLMFactory
        from app.services.llm.deepseek_provider import DeepSeekProvider

        monkeypatch.setenv("DEFAULT_LLM_PROVIDER", "deepseek")

        # No explicit name: the factory is expected to consult the env var.
        default_provider = LLMFactory.create()

        assert isinstance(default_provider, DeepSeekProvider)

    def test_factory_unknown_provider_raises(self):
        """An unrecognised provider name raises ``LLMError``."""
        from app.services.llm.factory import LLMFactory
        from app.services.llm.base import LLMError

        with pytest.raises(LLMError):
            LLMFactory.create("nonexistent_provider_xyz")

    def test_factory_list_providers(self):
        """``list_providers`` includes both ``openai`` and ``deepseek``."""
        from app.services.llm.factory import LLMFactory

        available = LLMFactory.list_providers()

        for expected_name in ("openai", "deepseek"):
            assert expected_name in available
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# OpenAIProvider tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _make_chat_response(content="Hello", model="gpt-4o-mini"):
|
|
"""构造 OpenAI chat completions 响应体"""
|
|
return {
|
|
"id": "chatcmpl-test",
|
|
"object": "chat.completion",
|
|
"choices": [
|
|
{
|
|
"index": 0,
|
|
"message": {"role": "assistant", "content": content},
|
|
"finish_reason": "stop",
|
|
}
|
|
],
|
|
"model": model,
|
|
"usage": {
|
|
"prompt_tokens": 10,
|
|
"completion_tokens": 5,
|
|
"total_tokens": 15,
|
|
},
|
|
}
|
|
|
|
|
|
class TestOpenAIProvider:
    """Tests for ``OpenAIProvider`` with its HTTP client replaced by mocks."""

    @pytest.fixture
    def provider(self):
        """Return an ``OpenAIProvider`` built with a dummy key and ``gpt-4o-mini``."""
        from app.services.llm.openai_provider import OpenAIProvider

        return OpenAIProvider(api_key="test-key", model="gpt-4o-mini")

    @pytest.mark.asyncio
    async def test_openai_chat_success(self, provider):
        """A mocked 200 response is returned from ``chat`` as an ``LLMResponse``."""
        from app.services.llm.base import LLMResponse

        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = _make_chat_response("test content")

        with patch.object(provider._client, "post", new=AsyncMock(return_value=mock_response)):
            result = await provider.chat([{"role": "user", "content": "Hello"}])

        assert isinstance(result, LLMResponse)
        assert result.content == "test content"
        assert result.model == "gpt-4o-mini"
        assert "prompt_tokens" in result.usage
        assert "completion_tokens" in result.usage
        assert "total_tokens" in result.usage

    @pytest.mark.asyncio
    async def test_openai_chat_retry_on_429(self, provider):
        """A persistent 429 is retried and finally surfaces as ``LLMError``."""
        from app.services.llm.base import LLMError

        mock_429 = MagicMock()
        mock_429.status_code = 429
        mock_429.text = "Rate limit exceeded"
        mock_429.headers = {}

        # Keep a handle on the mock so the number of attempts can be checked.
        post_mock = AsyncMock(return_value=mock_429)

        with patch.object(provider._client, "post", new=post_mock):
            # Patched so that waiting between attempts does not really sleep.
            with patch("asyncio.sleep", new=AsyncMock()):
                with pytest.raises(LLMError) as exc_info:
                    await provider.chat([{"role": "user", "content": "Hi"}])

        assert exc_info.value.status_code == 429
        # Without this check a provider that gave up after a single attempt
        # would still pass, and the retry behaviour would go unverified.
        assert post_mock.await_count > 1

    @pytest.mark.asyncio
    async def test_openai_chat_error_on_401(self, provider):
        """A 401 response raises ``LLMError`` immediately, without retrying."""
        from app.services.llm.base import LLMError

        mock_401 = MagicMock()
        mock_401.status_code = 401
        mock_401.text = "Unauthorized"

        post_mock = AsyncMock(return_value=mock_401)

        with patch.object(provider._client, "post", new=post_mock):
            with pytest.raises(LLMError) as exc_info:
                await provider.chat([{"role": "user", "content": "Hi"}])

        assert exc_info.value.status_code == 401
        # 401 is not retryable, so exactly one request must have been made.
        assert post_mock.await_count == 1

    @pytest.mark.asyncio
    async def test_openai_stream_parsing(self, provider):
        """``chat_stream`` yields one token per SSE content delta."""
        sse_lines = [
            'data: {"choices": [{"delta": {"content": "Hello"}}]}',
            'data: {"choices": [{"delta": {"content": " World"}}]}',
            'data: [DONE]',
        ]

        class MockStreamResponse:
            """Minimal async-context-manager stand-in for a streaming response."""

            status_code = 200
            headers = {}

            async def aread(self):
                """Return an empty body."""
                return b""

            async def aiter_lines(self):
                """Yield the canned SSE lines one at a time."""
                for line in sse_lines:
                    yield line

            async def __aenter__(self):
                return self

            async def __aexit__(self, *args):
                pass

        with patch.object(provider._client, "stream", return_value=MockStreamResponse()):
            tokens = [
                token
                async for token in provider.chat_stream([{"role": "user", "content": "Hi"}])
            ]

        # The terminating "[DONE]" line must not produce a token.
        assert tokens == ["Hello", " World"]

    def test_llm_response_structure(self):
        """``LLMResponse`` exposes content, model and a zero-initialised usage."""
        from app.services.llm.base import LLMResponse

        resp = LLMResponse(content="test", model="gpt-4o-mini")

        assert hasattr(resp, "content")
        assert hasattr(resp, "model")
        assert hasattr(resp, "usage")
        assert resp.usage["prompt_tokens"] == 0
        assert resp.usage["completion_tokens"] == 0
        assert resp.usage["total_tokens"] == 0

    def test_provider_properties(self, provider):
        """The provider reports its name, model and maximum context length."""
        assert provider.provider_name == "openai"
        assert provider.model_name == "gpt-4o-mini"
        assert provider.max_context_length == 128_000
|