"""Unit tests for the LLM providers (LLMFactory, OpenAIProvider, LLMResponse)."""

import json
import os
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture(autouse=True)
def set_openai_api_key(monkeypatch):
    """Provide dummy API keys for every test.

    Sets both ``OPENAI_API_KEY`` and ``DEEPSEEK_API_KEY`` so provider
    constructors that require a key do not raise while the tests run.
    """
    monkeypatch.setenv("OPENAI_API_KEY", "test-key-fixture")
    monkeypatch.setenv("DEEPSEEK_API_KEY", "test-deepseek-key")


# ---------------------------------------------------------------------------
# LLMFactory tests
# ---------------------------------------------------------------------------


class TestLLMFactory:
    def test_factory_create_openai(self):
        """LLMFactory.create('openai') returns an OpenAIProvider instance."""
        from app.services.llm.factory import LLMFactory
        from app.services.llm.openai_provider import OpenAIProvider

        provider = LLMFactory.create("openai")

        assert isinstance(provider, OpenAIProvider)
        assert provider.provider_name == "openai"

    def test_factory_create_deepseek(self):
        """LLMFactory.create('deepseek') returns a DeepSeekProvider instance."""
        from app.services.llm.factory import LLMFactory
        from app.services.llm.deepseek_provider import DeepSeekProvider

        provider = LLMFactory.create("deepseek")

        assert isinstance(provider, DeepSeekProvider)
        assert provider.provider_name == "deepseek"

    def test_factory_default_from_env(self, monkeypatch):
        """The DEFAULT_LLM_PROVIDER env var selects the provider when no name is given."""
        from app.services.llm.factory import LLMFactory
        from app.services.llm.deepseek_provider import DeepSeekProvider

        monkeypatch.setenv("DEFAULT_LLM_PROVIDER", "deepseek")
        provider = LLMFactory.create()

        assert isinstance(provider, DeepSeekProvider)

    def test_factory_unknown_provider_raises(self):
        """An unknown provider name raises LLMError."""
        from app.services.llm.factory import LLMFactory
        from app.services.llm.base import LLMError

        with pytest.raises(LLMError):
            LLMFactory.create("nonexistent_provider_xyz")

    def test_factory_list_providers(self):
        """list_providers() includes both 'openai' and 'deepseek'."""
        from app.services.llm.factory import LLMFactory

        providers = LLMFactory.list_providers()

        assert "openai" in providers
        assert "deepseek" in providers


# ---------------------------------------------------------------------------
# OpenAIProvider tests
# ---------------------------------------------------------------------------


def _make_chat_response(content="Hello", model="gpt-4o-mini"):
    """Build a minimal OpenAI chat-completions response body.

    Args:
        content: Text placed in the single assistant message.
        model: Value for the top-level ``model`` field.

    Returns:
        A dict shaped like the JSON payload of a non-streaming
        ``/chat/completions`` reply, including a ``usage`` section.
    """
    return {
        "id": "chatcmpl-test",
        "object": "chat.completion",
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": content},
                "finish_reason": "stop",
            }
        ],
        "model": model,
        "usage": {
            "prompt_tokens": 10,
            "completion_tokens": 5,
            "total_tokens": 15,
        },
    }


class TestOpenAIProvider:
    @pytest.fixture
    def provider(self):
        """An OpenAIProvider built with an explicit test key and model."""
        from app.services.llm.openai_provider import OpenAIProvider

        return OpenAIProvider(api_key="test-key", model="gpt-4o-mini")

    @pytest.mark.asyncio
    async def test_openai_chat_success(self, provider):
        """A mocked 200 response is parsed into an LLMResponse."""
        from app.services.llm.base import LLMResponse

        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = _make_chat_response("test content")

        with patch.object(
            provider._client, "post", new=AsyncMock(return_value=mock_response)
        ):
            result = await provider.chat([{"role": "user", "content": "Hello"}])

        assert isinstance(result, LLMResponse)
        assert result.content == "test content"
        assert result.model == "gpt-4o-mini"
        assert "prompt_tokens" in result.usage
        assert "completion_tokens" in result.usage
        assert "total_tokens" in result.usage

    @pytest.mark.asyncio
    async def test_openai_chat_retry_on_429(self, provider):
        """A persistent 429 is retried, then raises LLMError once retries run out."""
        from app.services.llm.base import LLMError

        mock_429 = MagicMock()
        mock_429.status_code = 429
        mock_429.text = "Rate limit exceeded"
        mock_429.headers = {}
        mock_post = AsyncMock(return_value=mock_429)

        # asyncio.sleep is patched so any backoff delay does not really wait.
        with patch.object(provider._client, "post", new=mock_post), patch(
            "asyncio.sleep", new=AsyncMock()
        ):
            with pytest.raises(LLMError) as exc_info:
                await provider.chat([{"role": "user", "content": "Hi"}])

        assert exc_info.value.status_code == 429
        # Unlike 401 (see below), 429 is expected to be retried, so more
        # than one request must have been sent before giving up.
        assert mock_post.await_count > 1

    @pytest.mark.asyncio
    async def test_openai_chat_error_on_401(self, provider):
        """A 401 raises LLMError immediately, without retrying."""
        from app.services.llm.base import LLMError

        mock_401 = MagicMock()
        mock_401.status_code = 401
        mock_401.text = "Unauthorized"
        mock_401.headers = {}
        mock_post = AsyncMock(return_value=mock_401)

        with patch.object(provider._client, "post", new=mock_post):
            with pytest.raises(LLMError) as exc_info:
                await provider.chat([{"role": "user", "content": "Hi"}])

        assert exc_info.value.status_code == 401
        # 401 is not retryable: exactly one request is sent.
        mock_post.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_openai_stream_parsing(self, provider):
        """A mocked SSE stream yields one token per content delta."""
        sse_lines = [
            'data: {"choices": [{"delta": {"content": "Hello"}}]}',
            'data: {"choices": [{"delta": {"content": " World"}}]}',
            'data: [DONE]',
        ]

        class MockStreamResponse:
            """Async context manager mimicking the streamed response object."""

            status_code = 200
            headers = {}

            async def aread(self):
                return b""

            async def aiter_lines(self):
                for line in sse_lines:
                    yield line

            async def __aenter__(self):
                return self

            async def __aexit__(self, *args):
                pass

        with patch.object(
            provider._client, "stream", return_value=MockStreamResponse()
        ):
            tokens = [
                token
                async for token in provider.chat_stream(
                    [{"role": "user", "content": "Hi"}]
                )
            ]

        assert tokens == ["Hello", " World"]

    def test_llm_response_structure(self):
        """LLMResponse exposes content/model/usage, with usage counters defaulting to 0."""
        from app.services.llm.base import LLMResponse

        resp = LLMResponse(content="test", model="gpt-4o-mini")

        assert hasattr(resp, "content")
        assert hasattr(resp, "model")
        assert hasattr(resp, "usage")
        assert resp.usage["prompt_tokens"] == 0
        assert resp.usage["completion_tokens"] == 0
        assert resp.usage["total_tokens"] == 0

    def test_provider_properties(self, provider):
        """Provider metadata properties report the expected values."""
        assert provider.provider_name == "openai"
        assert provider.model_name == "gpt-4o-mini"
        assert provider.max_context_length == 128_000