"""U15 — unit tests for LitellmProvider.

``litellm.acompletion`` is mocked with ``unittest.mock.AsyncMock``. Covered:

- six provider types (openai/anthropic/gemini/doubao/wenxin/yuanbao)
  called through LiteLLM
- an unknown provider_type falling back to the ``openai/`` prefix
- custom base_url pass-through
- tools pass-through + parsing of tool_calls in the response
- streaming chunk iteration + the terminating chunk
- streaming with no messages at all
- LiteLLM exception → LLMProviderError
- latency_ms being non-negative
"""

from __future__ import annotations

from types import SimpleNamespace
from unittest.mock import AsyncMock, patch

import pytest

from agentkit.core.exceptions import LLMProviderError
from agentkit.llm.protocol import LLMRequest, LLMResponse
from agentkit.llm.providers.litellm_provider import (
    LitellmProvider,
    _model_prefix_for,
    create_litellm_provider,
)

# ----------------------------------------------------------------------
# Test helpers: build OpenAI-shaped fake responses / chunks
# ----------------------------------------------------------------------

# Dotted path of the LiteLLM entry point that every test replaces.
_PATCH_TARGET = "litellm.acompletion"


def _fake_response(
    content: str = "Hello!",
    model: str = "gpt-4o-mini",
    prompt_tokens: int = 10,
    completion_tokens: int = 5,
    tool_calls: list | None = None,
) -> SimpleNamespace:
    """Build a non-streaming fake response (OpenAI ChatCompletion shape)."""
    usage = SimpleNamespace(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    choice = SimpleNamespace(
        message=SimpleNamespace(content=content, tool_calls=tool_calls),
    )
    return SimpleNamespace(choices=[choice], usage=usage, model=model)


def _fake_tool_call(
    tc_id: str = "call_1",
    name: str = "get_weather",
    arguments: str = '{"city": "Beijing"}',
) -> SimpleNamespace:
    """Build a non-streaming tool_call object."""
    function = SimpleNamespace(name=name, arguments=arguments)
    return SimpleNamespace(id=tc_id, function=function)


def _fake_stream_chunk(
    content: str = "",
    model: str = "gpt-4o-mini",
    tool_calls_delta: list | None = None,
    usage: SimpleNamespace | None = None,
) -> SimpleNamespace:
    """Build a streaming chunk (OpenAI ChatCompletionChunk shape)."""
    choice = SimpleNamespace(
        delta=SimpleNamespace(content=content, tool_calls=tool_calls_delta),
    )
    return SimpleNamespace(choices=[choice], model=model, usage=usage)


def _fake_stream_tool_call_delta(
    index: int = 0,
    tc_id: str | None = "call_1",
    name: str | None = "get_weather",
    arguments_fragment: str | None = None,
) -> SimpleNamespace:
    """Build one streaming tool_call delta fragment."""
    function = SimpleNamespace(name=name, arguments=arguments_fragment)
    return SimpleNamespace(index=index, id=tc_id, function=function)


def _user_request(model: str, content: str = "Hi", **extra) -> LLMRequest:
    """Build an LLMRequest holding a single user message.

    Any ``extra`` keyword arguments are forwarded to ``LLMRequest`` unchanged.
    """
    return LLMRequest(
        messages=[{"role": "user", "content": content}],
        model=model,
        **extra,
    )


def _stream_of(chunks: list):
    """Return an async-generator function that replays ``chunks`` in order.

    The returned callable ignores whatever keyword arguments it is given, so
    it can stand in for ``litellm.acompletion`` in streaming tests.
    """

    async def _replay(**_kwargs):
        for item in chunks:
            yield item

    return _replay


# ----------------------------------------------------------------------
# 1-7: provider_type → model_prefix mapping + basic chat
# ----------------------------------------------------------------------

_PREFIX_CASES = [
    ("openai", "openai/"),
    ("anthropic", "anthropic/"),
    ("gemini", "gemini/"),
    ("doubao", "volcengine/"),
    ("wenxin", "openai/"),  # ponytail: wenxin falls back to openai/
    ("yuanbao", "hunyuan/"),
    ("unknown_type", "openai/"),  # unknown → openai/ fallback
]

_CHAT_CASES = [
    ("openai", "openai/", "gpt-4o-mini"),
    ("anthropic", "anthropic/", "claude-sonnet-4-20250514"),
    ("gemini", "gemini/", "gemini-2.0-flash"),
    ("doubao", "volcengine/", "doubao-pro-32k"),
    ("wenxin", "openai/", "ernie-4.5-turbo-128k"),
    ("yuanbao", "hunyuan/", "hunyuan-pro"),
]


@pytest.mark.parametrize(
    "provider_type, expected_prefix",
    _PREFIX_CASES,
    ids=["openai", "anthropic", "gemini", "doubao", "wenxin", "yuanbao", "unknown"],
)
def test_model_prefix_mapping(provider_type: str, expected_prefix: str):
    """Check the provider_type → LiteLLM model prefix map (incl. unknown fallback)."""
    actual_prefix = _model_prefix_for(provider_type)
    assert actual_prefix == expected_prefix


@pytest.mark.parametrize(
    "provider_type, expected_prefix, model_name",
    _CHAT_CASES,
    ids=["openai", "anthropic", "gemini", "doubao", "wenxin", "yuanbao"],
)
async def test_chat_via_litellm_for_each_provider_type(
    provider_type: str,
    expected_prefix: str,
    model_name: str,
):
    """Six provider types via LiteLLM — check the model prefix and response translation."""
    provider = create_litellm_provider(provider_type, api_key="sk-test")
    assert provider._model_prefix == expected_prefix

    stub = AsyncMock(return_value=_fake_response(content="ok", model=model_name))
    with patch(_PATCH_TARGET, new=stub):
        response = await provider.chat(_user_request(model_name))

    assert isinstance(response, LLMResponse)
    assert response.content == "ok"
    assert response.model == model_name
    assert response.usage.prompt_tokens == 10
    assert response.usage.completion_tokens == 5


# ----------------------------------------------------------------------
# 8: custom base_url pass-through
# ----------------------------------------------------------------------


async def test_custom_base_url_passed_through():
    """A custom base_url must reach litellm.acompletion as ``api_base``."""
    provider = create_litellm_provider(
        "openai",
        api_key="sk-test",
        base_url="https://custom.api/v1",
    )
    mock_acompletion = AsyncMock(return_value=_fake_response())

    with patch(_PATCH_TARGET, new=mock_acompletion):
        await provider.chat(_user_request("gpt-4o-mini"))

    sent = mock_acompletion.call_args.kwargs
    assert sent["api_base"] == "https://custom.api/v1"
    assert sent["api_key"] == "sk-test"
    assert sent["model"] == "openai/gpt-4o-mini"


# ----------------------------------------------------------------------
# 9: tools pass-through
# ----------------------------------------------------------------------


async def test_tools_passed_through():
    """request.tools must be forwarded to litellm.acompletion."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    mock_acompletion = AsyncMock(return_value=_fake_response())
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get weather",
                "parameters": {"type": "object", "properties": {}},
            },
        }
    ]

    with patch(_PATCH_TARGET, new=mock_acompletion):
        await provider.chat(
            _user_request(
                "gpt-4o",
                content="weather?",
                tools=tools,
                tool_choice="auto",
            )
        )

    sent = mock_acompletion.call_args.kwargs
    assert sent["tools"] == tools
    assert sent["tool_choice"] == "auto"


# ----------------------------------------------------------------------
# 10: tool_calls response parsing
# ----------------------------------------------------------------------


async def test_tool_calls_in_response_parsed():
    """tool_calls in a non-streaming response must land in LLMResponse.tool_calls."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    weather_call = _fake_tool_call(
        tc_id="call_abc",
        name="get_weather",
        arguments='{"city": "Beijing"}',
    )
    stub = AsyncMock(
        return_value=_fake_response(content="", tool_calls=[weather_call]),
    )

    with patch(_PATCH_TARGET, new=stub):
        response = await provider.chat(_user_request("gpt-4o"))

    assert response.has_tool_calls
    assert len(response.tool_calls) == 1
    parsed = response.tool_calls[0]
    assert parsed.id == "call_abc"
    assert parsed.name == "get_weather"
    assert parsed.arguments == {"city": "Beijing"}


async def test_tool_calls_with_dict_arguments():
    """When tool_call.arguments is already a dict it is used as-is."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    dict_call = SimpleNamespace(
        id="call_1",
        function=SimpleNamespace(name="fn", arguments={"k": "v"}),
    )
    stub = AsyncMock(return_value=_fake_response(tool_calls=[dict_call]))

    with patch(_PATCH_TARGET, new=stub):
        response = await provider.chat(_user_request("gpt-4o"))

    assert response.tool_calls[0].arguments == {"k": "v"}


# ----------------------------------------------------------------------
# 11: streaming yields chunks
# ----------------------------------------------------------------------


async def test_streaming_yields_chunks():
    """Streaming yields several is_final=False chunks plus one is_final=True terminator."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    chunks = [
        _fake_stream_chunk(content="Hello"),
        _fake_stream_chunk(content=" world"),
        _fake_stream_chunk(
            usage=SimpleNamespace(prompt_tokens=8, completion_tokens=2),
        ),
    ]

    with patch(_PATCH_TARGET, new=_stream_of(chunks)):
        results = [
            chunk
            async for chunk in provider.chat_stream(_user_request("gpt-4o-mini"))
        ]

    # The first 3 are streaming chunks (is_final=False); the last one is the
    # terminating chunk (is_final=True).
    assert len(results) == 4
    assert results[0].content == "Hello"
    assert results[1].content == " world"
    assert not any(r.is_final for r in results[:3])

    terminator = results[-1]
    assert terminator.is_final is True
    # The terminating chunk carries the aggregated usage.
    assert terminator.usage is not None
    assert terminator.usage.prompt_tokens == 8
    assert terminator.usage.completion_tokens == 2


async def test_streaming_aggregates_tool_calls():
    """Streaming tool_call fragments must be aggregated into the terminating chunk."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    opening_fragment = _fake_stream_tool_call_delta(
        index=0,
        tc_id="call_1",
        name="get_weather",
        arguments_fragment='{"city":',
    )
    closing_fragment = _fake_stream_tool_call_delta(
        index=0,
        tc_id=None,
        name=None,
        arguments_fragment=' "Beijing"}',
    )
    chunks = [
        _fake_stream_chunk(tool_calls_delta=[opening_fragment]),
        _fake_stream_chunk(tool_calls_delta=[closing_fragment]),
    ]

    with patch(_PATCH_TARGET, new=_stream_of(chunks)):
        results = [
            chunk async for chunk in provider.chat_stream(_user_request("gpt-4o"))
        ]

    final = results[-1]
    assert final.is_final
    assert len(final.tool_calls) == 1
    merged = final.tool_calls[0]
    assert merged.id == "call_1"
    assert merged.name == "get_weather"
    assert merged.arguments == {"city": "Beijing"}


# ----------------------------------------------------------------------
# 12: streaming with no messages
# ----------------------------------------------------------------------


async def test_streaming_empty_messages_handled():
    """With no chunks at all, one empty is_final=True terminator is still yielded."""
    provider = create_litellm_provider("openai", api_key="sk-test")

    # An async generator that finishes without producing anything.
    with patch(_PATCH_TARGET, new=_stream_of([])):
        results = [
            chunk async for chunk in provider.chat_stream(_user_request("gpt-4o"))
        ]

    # At least one terminating chunk must have been yielded.
    assert len(results) >= 1
    assert results[-1].is_final is True
    assert results[-1].content == ""


# ----------------------------------------------------------------------
# 13: LiteLLM exception → LLMProviderError
# ----------------------------------------------------------------------


async def test_litellm_exception_raises_provider_error():
    """An exception from litellm.acompletion must be wrapped in LLMProviderError."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    exploding = AsyncMock(side_effect=RuntimeError("boom"))

    with patch(_PATCH_TARGET, new=exploding), pytest.raises(
        LLMProviderError
    ) as exc_info:
        await provider.chat(_user_request("gpt-4o"))

    assert "boom" in str(exc_info.value)


async def test_litellm_stream_exception_raises_provider_error():
    """An exception from streaming litellm.acompletion must be wrapped in LLMProviderError."""
    provider = create_litellm_provider("openai", api_key="sk-test")

    async def _failing_stream(**_kwargs):
        raise RuntimeError("stream boom")
        yield  # unreachable; present only to make this an async generator

    with patch(_PATCH_TARGET, new=_failing_stream), pytest.raises(
        LLMProviderError
    ) as exc_info:
        async for _ in provider.chat_stream(_user_request("gpt-4o")):
            pass

    assert "stream boom" in str(exc_info.value)


# ----------------------------------------------------------------------
# 14: latency is non-negative
# ----------------------------------------------------------------------


async def test_latency_measured_non_negative():
    """LLMResponse.latency_ms must be a non-negative number."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    stub = AsyncMock(return_value=_fake_response())

    with patch(_PATCH_TARGET, new=stub):
        response = await provider.chat(_user_request("gpt-4o"))

    assert response.latency_ms >= 0


# ----------------------------------------------------------------------
# Extra: num_retries=0 disables LiteLLM's built-in retry
# ----------------------------------------------------------------------


async def test_litellm_num_retries_disabled():
    """num_retries=0 must be sent to disable LiteLLM's own retry (gateway fallback owns it)."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    mock_acompletion = AsyncMock(return_value=_fake_response())

    with patch(_PATCH_TARGET, new=mock_acompletion):
        await provider.chat(_user_request("gpt-4o"))

    sent = mock_acompletion.call_args.kwargs
    assert sent["num_retries"] == 0
    assert sent["stream"] is False


async def test_streaming_passes_stream_true():
    """chat_stream must send stream=True."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    captured: dict = {}

    async def _capturing_stream(**kwargs):
        # Record what the provider sent, then hand back an empty async stream.
        captured.update(kwargs)

        async def _no_chunks():
            return
            yield  # unreachable; present only to make this an async generator

        return _no_chunks()

    with patch(_PATCH_TARGET, new=_capturing_stream):
        async for _ in provider.chat_stream(_user_request("gpt-4o")):
            pass

    assert captured["stream"] is True


# ----------------------------------------------------------------------
# Extra: timeout pass-through
# ----------------------------------------------------------------------


async def test_timeout_passed_through():
    """request.timeout must be forwarded to litellm.acompletion."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    mock_acompletion = AsyncMock(return_value=_fake_response())

    with patch(_PATCH_TARGET, new=mock_acompletion):
        await provider.chat(_user_request("gpt-4o", timeout=42.0))

    assert mock_acompletion.call_args.kwargs["timeout"] == 42.0


# ----------------------------------------------------------------------
# Extra: constructing LitellmProvider directly (bypassing the factory)
# ----------------------------------------------------------------------


async def test_direct_litellm_provider_construction():
    """Building LitellmProvider(model_prefix=...) by hand must work too."""
    provider = LitellmProvider(
        model_prefix="anthropic/",
        api_key="sk-test",
        provider_type="anthropic",
    )
    fake_resp = _fake_response(content="hi", model="claude-3")

    with patch(_PATCH_TARGET, new=AsyncMock(return_value=fake_resp)):
        response = await provider.chat(_user_request("claude-3"))
    assert response.content == "hi"

    # Check that the model prefix is joined onto the model name correctly.
    mock_acompletion = AsyncMock(return_value=fake_resp)
    with patch(_PATCH_TARGET, new=mock_acompletion):
        await provider.chat(_user_request("claude-3"))
    assert mock_acompletion.call_args.kwargs["model"] == "anthropic/claude-3"


# ----------------------------------------------------------------------
# Extra: tolerant parsing of JSON tool_calls
# ----------------------------------------------------------------------


async def test_tool_calls_invalid_json_arguments_wrapped():
    """Invalid-JSON tool_call.arguments must be wrapped as {"raw": ...} without crashing."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    broken_call = _fake_tool_call(
        tc_id="call_1",
        name="fn",
        arguments="not-valid-json{",
    )
    stub = AsyncMock(return_value=_fake_response(tool_calls=[broken_call]))

    with patch(_PATCH_TARGET, new=stub):
        response = await provider.chat(_user_request("gpt-4o"))

    assert response.tool_calls[0].arguments == {"raw": "not-valid-json{"}