"""Tests for the Gemini provider."""

import asyncio
import json
from unittest.mock import AsyncMock, MagicMock

import httpx
import pytest
from pytest_httpx import HTTPXMock

from agentkit.core.exceptions import LLMProviderError
from agentkit.llm.protocol import LLMRequest, LLMResponse, StreamChunk, TokenUsage
from agentkit.llm.providers.gemini import GeminiProvider

# Base URL for Gemini API (without key param - pytest-httpx matches without query).
# NOTE(review): not referenced by the tests below, which mock the HTTP client directly.
_GEMINI_BASE = "https://generativelanguage.googleapis.com"


# ---------------------------------------------------------------------------
# Shared test helpers
# ---------------------------------------------------------------------------


def _make_mock_response(status_code: int, json_data: dict) -> MagicMock:
    """Return a mock ``httpx.Response`` with ``status_code`` and ``json_data`` as body."""
    response = MagicMock(spec=httpx.Response)
    response.status_code = status_code
    response.json = MagicMock(return_value=json_data)
    response.content = json.dumps(json_data).encode()
    return response


def _make_mock_client(response: MagicMock) -> MagicMock:
    """Return a mock ``httpx.AsyncClient`` whose ``post`` resolves to ``response``."""
    client = MagicMock(spec=httpx.AsyncClient)
    client.post = AsyncMock(return_value=response)
    return client


def _posted_json(client: MagicMock) -> dict:
    """Return the ``json=`` payload of the most recent ``client.post`` call.

    Raises KeyError if the request body was not sent via ``json=``. That
    failure is clearer than silently asserting against an empty dict.
    """
    return client.post.call_args.kwargs["json"]


def _posted_url(client: MagicMock) -> str:
    """Return the URL of the most recent ``client.post`` call (positional or ``url=``)."""
    call_args = client.post.call_args
    return call_args.args[0] if call_args.args else call_args.kwargs.get("url", "")


def _make_stream_context(response: MagicMock) -> MagicMock:
    """Wrap ``response`` in a mock async context manager, as ``client.stream`` returns."""
    context = MagicMock()
    context.__aenter__ = AsyncMock(return_value=response)
    context.__aexit__ = AsyncMock(return_value=False)
    return context


def _hi_request() -> LLMRequest:
    """Return the minimal single-turn request used by many tests."""
    return LLMRequest(
        messages=[{"role": "user", "content": "Hi"}],
        model="gemini-2.0-flash",
    )


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


class TestGeminiMessageConversion:
    """Message format conversion tests."""

    def setup_method(self):
        self.provider = GeminiProvider(api_key="test-key")

    def test_system_message_extracted_as_system_instruction(self):
        """A system message should be extracted as systemInstruction."""
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello"},
        ]
        system_instruction, contents = self.provider._convert_messages(messages)

        assert system_instruction == {"parts": [{"text": "You are a helpful assistant."}]}
        assert len(contents) == 1
        assert contents[0]["role"] == "user"
        assert contents[0]["parts"] == [{"text": "Hello"}]

    def test_text_messages_converted_to_contents(self):
        """Plain text messages should be converted to Gemini contents."""
        messages = [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello!"},
            {"role": "user", "content": "How are you?"},
        ]
        system_instruction, contents = self.provider._convert_messages(messages)

        assert system_instruction is None
        assert len(contents) == 3
        assert contents[0] == {"role": "user", "parts": [{"text": "Hi"}]}
        assert contents[1] == {"role": "model", "parts": [{"text": "Hello!"}]}
        assert contents[2] == {"role": "user", "parts": [{"text": "How are you?"}]}

    def test_assistant_tool_calls_converted(self):
        """Assistant tool_calls should be converted to functionCall parts."""
        messages = [
            {"role": "user", "content": "What's the weather?"},
            {
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    {
                        "id": "call_123",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": '{"city": "Beijing"}',
                        },
                    }
                ],
            },
        ]
        _, contents = self.provider._convert_messages(messages)

        assert len(contents) == 2
        model_msg = contents[1]
        assert model_msg["role"] == "model"
        assert len(model_msg["parts"]) == 1
        assert "functionCall" in model_msg["parts"][0]
        assert model_msg["parts"][0]["functionCall"]["name"] == "get_weather"
        assert model_msg["parts"][0]["functionCall"]["args"] == {"city": "Beijing"}

    def test_assistant_tool_calls_with_text(self):
        """An assistant message with both text and tool_calls keeps both parts, text first."""
        messages = [
            {
                "role": "assistant",
                "content": "Let me check that.",
                "tool_calls": [
                    {
                        "id": "call_456",
                        "type": "function",
                        "function": {
                            "name": "search",
                            "arguments": '{"q": "test"}',
                        },
                    }
                ],
            },
        ]
        _, contents = self.provider._convert_messages(messages)

        parts = contents[0]["parts"]
        assert len(parts) == 2
        assert parts[0] == {"text": "Let me check that."}
        assert "functionCall" in parts[1]

    def test_tool_result_converted_to_function_response(self):
        """tool-role messages should be converted to functionResponse parts."""
        messages = [
            {
                "role": "tool",
                "tool_call_id": "call_123",
                "name": "get_weather",
                "content": "Sunny, 25°C",
            },
        ]
        _, contents = self.provider._convert_messages(messages)

        assert len(contents) == 1
        msg = contents[0]
        assert msg["role"] == "user"
        assert len(msg["parts"]) == 1
        assert "functionResponse" in msg["parts"][0]
        assert msg["parts"][0]["functionResponse"]["name"] == "get_weather"
        assert msg["parts"][0]["functionResponse"]["response"]["content"] == "Sunny, 25°C"

    def test_user_with_tool_call_id_converted(self):
        """A user message carrying tool_call_id should also become a functionResponse."""
        messages = [
            {
                "role": "user",
                "tool_call_id": "call_789",
                "content": "Result data",
            },
        ]
        _, contents = self.provider._convert_messages(messages)

        msg = contents[0]
        assert msg["role"] == "user"
        assert "functionResponse" in msg["parts"][0]

    def test_no_system_message(self):
        """Without a system message, the returned systemInstruction is None."""
        messages = [
            {"role": "user", "content": "Hello"},
        ]
        system_instruction, _ = self.provider._convert_messages(messages)
        assert system_instruction is None


class TestGeminiToolConversion:
    """Tool format conversion tests."""

    def setup_method(self):
        self.provider = GeminiProvider(api_key="test-key")

    def test_convert_openai_tools_to_gemini(self):
        """OpenAI function-format tools should become Gemini functionDeclarations."""
        tools = [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get weather for a city",
                    "parameters": {
                        "type": "object",
                        "properties": {"city": {"type": "string"}},
                    },
                },
            }
        ]
        result = self.provider._convert_tools(tools)

        assert len(result) == 1
        assert "functionDeclarations" in result[0]
        declarations = result[0]["functionDeclarations"]
        assert len(declarations) == 1
        assert declarations[0]["name"] == "get_weather"
        assert declarations[0]["description"] == "Get weather for a city"
        assert declarations[0]["parameters"] == {
            "type": "object",
            "properties": {"city": {"type": "string"}},
        }

    def test_convert_empty_tools(self):
        """An empty tool list should convert to an empty list."""
        result = self.provider._convert_tools([])
        assert result == []

    def test_convert_tool_choice_auto(self):
        """tool_choice="auto" should map to Gemini AUTO mode."""
        result = self.provider._convert_tool_choice("auto")
        assert result == {"functionCallingConfig": {"mode": "AUTO"}}

    def test_convert_tool_choice_required(self):
        """tool_choice="required" should map to Gemini ANY mode."""
        result = self.provider._convert_tool_choice("required")
        assert result == {"functionCallingConfig": {"mode": "ANY"}}

    def test_convert_tool_choice_none(self):
        """tool_choice="none" should map to Gemini NONE mode."""
        result = self.provider._convert_tool_choice("none")
        assert result == {"functionCallingConfig": {"mode": "NONE"}}

    def test_convert_tool_choice_specific_tool(self):
        """A tool_choice naming a specific tool should map to Gemini AUTO mode."""
        result = self.provider._convert_tool_choice("get_weather")
        assert result == {"functionCallingConfig": {"mode": "AUTO"}}


class TestGeminiResponseParsing:
    """Response parsing tests."""

    def setup_method(self):
        self.provider = GeminiProvider(api_key="test-key")

    def test_parse_text_response(self):
        """Parse a plain-text response."""
        data = {
            "candidates": [
                {
                    "content": {
                        "parts": [{"text": "Hello! How can I help?"}],
                        "role": "model",
                    },
                    "finishReason": "STOP",
                }
            ],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 6,
                "totalTokenCount": 16,
            },
        }
        response = self.provider._parse_response(data, "gemini-2.0-flash")

        assert isinstance(response, LLMResponse)
        assert response.content == "Hello! How can I help?"
        assert response.usage.prompt_tokens == 10
        assert response.usage.completion_tokens == 6
        assert not response.has_tool_calls

    def test_parse_function_call_response(self):
        """Parse a response that contains text plus a functionCall."""
        data = {
            "candidates": [
                {
                    "content": {
                        "parts": [
                            {"text": "Let me check the weather."},
                            {
                                "functionCall": {
                                    "name": "get_weather",
                                    "args": {"city": "Beijing"},
                                }
                            },
                        ],
                        "role": "model",
                    },
                    "finishReason": "STOP",
                }
            ],
            "usageMetadata": {
                "promptTokenCount": 20,
                "candidatesTokenCount": 15,
                "totalTokenCount": 35,
            },
        }
        response = self.provider._parse_response(data, "gemini-2.0-flash")

        assert response.content == "Let me check the weather."
        assert response.has_tool_calls
        assert len(response.tool_calls) == 1
        assert response.tool_calls[0].name == "get_weather"
        assert response.tool_calls[0].arguments == {"city": "Beijing"}

    def test_parse_multiple_function_calls(self):
        """Parse a response that contains several functionCall parts, in order."""
        data = {
            "candidates": [
                {
                    "content": {
                        "parts": [
                            {
                                "functionCall": {
                                    "name": "get_weather",
                                    "args": {"city": "Beijing"},
                                }
                            },
                            {
                                "functionCall": {
                                    "name": "get_weather",
                                    "args": {"city": "Shanghai"},
                                }
                            },
                        ],
                        "role": "model",
                    },
                    "finishReason": "STOP",
                }
            ],
            "usageMetadata": {
                "promptTokenCount": 25,
                "candidatesTokenCount": 20,
                "totalTokenCount": 45,
            },
        }
        response = self.provider._parse_response(data, "gemini-2.0-flash")

        assert len(response.tool_calls) == 2
        assert response.tool_calls[0].name == "get_weather"
        assert response.tool_calls[0].arguments == {"city": "Beijing"}
        assert response.tool_calls[1].arguments == {"city": "Shanghai"}

    def test_parse_empty_candidates(self):
        """An empty candidates list should parse to empty content and no tool calls."""
        data = {
            "candidates": [],
            "usageMetadata": {
                "promptTokenCount": 5,
                "candidatesTokenCount": 0,
            },
        }
        response = self.provider._parse_response(data, "gemini-2.0-flash")

        assert response.content == ""
        assert not response.has_tool_calls

    def test_parse_model_version_in_response(self):
        """modelVersion in the response should be reported as the model."""
        data = {
            "candidates": [
                {
                    "content": {
                        "parts": [{"text": "Hi"}],
                        "role": "model",
                    },
                    "finishReason": "STOP",
                }
            ],
            "modelVersion": "gemini-2.0-flash-001",
            "usageMetadata": {
                "promptTokenCount": 5,
                "candidatesTokenCount": 2,
            },
        }
        response = self.provider._parse_response(data, "gemini-2.0-flash")
        assert response.model == "gemini-2.0-flash-001"


class TestGeminiChat:
    """Integration tests for chat(), using a mock client instead of httpx_mock."""

    async def test_chat_returns_llm_response(self):
        """chat() should return an LLMResponse built from the API payload."""
        mock_response = _make_mock_response(200, {
            "candidates": [
                {
                    "content": {
                        "parts": [{"text": "Hello from Gemini!"}],
                        "role": "model",
                    },
                    "finishReason": "STOP",
                }
            ],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15,
            },
        })

        async def delayed_post(*args, **kwargs):
            # A mock that resolves instantly can yield latency_ms == 0 on
            # coarse clocks; a short real await makes the latency check deterministic.
            await asyncio.sleep(0.02)
            return mock_response

        mock_client = MagicMock(spec=httpx.AsyncClient)
        mock_client.post = AsyncMock(side_effect=delayed_post)

        provider = GeminiProvider(api_key="test-key")
        provider._client = mock_client

        response = await provider.chat(_hi_request())

        assert isinstance(response, LLMResponse)
        assert response.content == "Hello from Gemini!"
        assert response.usage.prompt_tokens == 10
        assert response.usage.completion_tokens == 5
        assert response.latency_ms > 0

    async def test_chat_with_system_message(self):
        """A system message should be sent as systemInstruction, not in contents."""
        mock_client = _make_mock_client(_make_mock_response(200, {
            "candidates": [
                {
                    "content": {
                        "parts": [{"text": "I am a helpful assistant."}],
                        "role": "model",
                    },
                    "finishReason": "STOP",
                }
            ],
            "usageMetadata": {
                "promptTokenCount": 15,
                "candidatesTokenCount": 8,
            },
        }))

        provider = GeminiProvider(api_key="test-key")
        provider._client = mock_client

        request = LLMRequest(
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Who are you?"},
            ],
            model="gemini-2.0-flash",
        )
        response = await provider.chat(request)
        assert response.content == "I am a helpful assistant."

        # Verify the request payload.
        request_body = _posted_json(mock_client)
        assert "systemInstruction" in request_body
        assert request_body["systemInstruction"]["parts"][0]["text"] == "You are a helpful assistant."
        # System should NOT be in contents.
        for msg in request_body["contents"]:
            assert msg["role"] != "system"

    async def test_chat_with_tools(self):
        """A request with tools should be converted to Gemini's tools/toolConfig format."""
        mock_client = _make_mock_client(_make_mock_response(200, {
            "candidates": [
                {
                    "content": {
                        "parts": [
                            {
                                "functionCall": {
                                    "name": "get_weather",
                                    "args": {"city": "Tokyo"},
                                }
                            }
                        ],
                        "role": "model",
                    },
                    "finishReason": "STOP",
                }
            ],
            "usageMetadata": {
                "promptTokenCount": 30,
                "candidatesTokenCount": 20,
            },
        }))

        provider = GeminiProvider(api_key="test-key")
        provider._client = mock_client

        request = LLMRequest(
            messages=[{"role": "user", "content": "Weather in Tokyo?"}],
            model="gemini-2.0-flash",
            tools=[
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "description": "Get weather",
                        "parameters": {
                            "type": "object",
                            "properties": {"city": {"type": "string"}},
                        },
                    },
                }
            ],
        )
        response = await provider.chat(request)

        assert response.has_tool_calls
        assert response.tool_calls[0].name == "get_weather"
        assert response.tool_calls[0].arguments == {"city": "Tokyo"}

        # Verify the request format.
        request_body = _posted_json(mock_client)
        assert "tools" in request_body
        assert "functionDeclarations" in request_body["tools"][0]
        assert request_body["tools"][0]["functionDeclarations"][0]["name"] == "get_weather"
        assert "toolConfig" in request_body
        assert request_body["toolConfig"]["functionCallingConfig"]["mode"] == "AUTO"

    async def test_chat_api_key_in_url(self):
        """The API key should be passed as a URL query parameter."""
        mock_client = _make_mock_client(_make_mock_response(200, {
            "candidates": [
                {
                    "content": {
                        "parts": [{"text": "OK"}],
                        "role": "model",
                    },
                    "finishReason": "STOP",
                }
            ],
            "usageMetadata": {
                "promptTokenCount": 5,
                "candidatesTokenCount": 2,
            },
        }))

        provider = GeminiProvider(api_key="my-secret-key")
        provider._client = mock_client

        await provider.chat(_hi_request())

        assert "key=my-secret-key" in _posted_url(mock_client)

    async def test_chat_with_custom_base_url(self):
        """A custom base_url should be used for the request."""
        mock_client = _make_mock_client(_make_mock_response(200, {
            "candidates": [
                {
                    "content": {
                        "parts": [{"text": "Proxy response"}],
                        "role": "model",
                    },
                    "finishReason": "STOP",
                }
            ],
            "usageMetadata": {
                "promptTokenCount": 5,
                "candidatesTokenCount": 3,
            },
        }))

        provider = GeminiProvider(
            api_key="test-key",
            base_url="https://custom-proxy.example.com",
        )
        provider._client = mock_client

        response = await provider.chat(_hi_request())

        assert response.content == "Proxy response"
        assert "custom-proxy.example.com" in _posted_url(mock_client)


class TestGeminiStreaming:
    """Tests for chat_stream()."""

    def _make_stream_response(self, sse_lines: list[str]):
        """Create a mock ``client.stream`` context with HTTP 200 that yields ``sse_lines``."""
        response = MagicMock()
        response.status_code = 200

        async def aiter_lines():
            for line in sse_lines:
                yield line

        response.aiter_lines = aiter_lines
        response.aread = AsyncMock(return_value=b"")
        return _make_stream_context(response)

    def _make_provider(self, sse_lines: list[str]) -> GeminiProvider:
        """Return a provider whose mock client streams ``sse_lines``."""
        mock_client = MagicMock()
        mock_client.stream = MagicMock(return_value=self._make_stream_response(sse_lines))
        provider = GeminiProvider(api_key="test-key")
        provider._client = mock_client
        return provider

    async def test_stream_text_response(self):
        """A streamed text response should be parsed into content chunks."""
        sse_lines = [
            'data: {"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":3,"totalTokenCount":8}}',
            '',
            'data: {"candidates":[{"content":{"parts":[{"text":" world"}],"role":"model"},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":5,"totalTokenCount":10}}',
            '',
        ]
        provider = self._make_provider(sse_lines)

        chunks = [chunk async for chunk in provider.chat_stream(_hi_request())]

        text_chunks = [c for c in chunks if c.content]
        assert len(text_chunks) == 2
        assert text_chunks[0].content == "Hello"
        assert text_chunks[1].content == " world"

    async def test_stream_function_call_response(self):
        """A streamed functionCall should surface as tool_calls on the final chunk."""
        sse_lines = [
            'data: {"candidates":[{"content":{"parts":[{"functionCall":{"name":"get_weather","args":{"city":"Paris"}}}],"role":"model"},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":20,"candidatesTokenCount":15}}',
            '',
        ]
        provider = self._make_provider(sse_lines)

        request = LLMRequest(
            messages=[{"role": "user", "content": "Weather in Paris?"}],
            model="gemini-2.0-flash",
            tools=[
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "description": "Get weather",
                        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
                    },
                }
            ],
        )
        chunks = [chunk async for chunk in provider.chat_stream(request)]

        final_chunks = [c for c in chunks if c.is_final]
        assert len(final_chunks) == 1
        assert len(final_chunks[0].tool_calls) == 1
        assert final_chunks[0].tool_calls[0].name == "get_weather"
        assert final_chunks[0].tool_calls[0].arguments == {"city": "Paris"}

    async def test_stream_with_usage_metadata(self):
        """The final streamed chunk should carry usage information."""
        sse_lines = [
            'data: {"candidates":[{"content":{"parts":[{"text":"Hi"}],"role":"model"},"finishReason":"STOP"}]}',
            '',
            'data: {"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15}}',
            '',
        ]
        provider = self._make_provider(sse_lines)

        chunks = [chunk async for chunk in provider.chat_stream(_hi_request())]

        final_chunks = [c for c in chunks if c.is_final]
        assert len(final_chunks) == 1
        assert final_chunks[0].usage is not None
        assert final_chunks[0].usage.prompt_tokens == 10
        assert final_chunks[0].usage.completion_tokens == 5

    async def test_stream_non_200_status(self):
        """A non-200 streaming status should raise LLMProviderError."""
        response = MagicMock()
        response.status_code = 429
        response.aread = AsyncMock(return_value=b'{"error":{"code":429,"message":"Rate limit exceeded"}}')

        mock_client = MagicMock()
        mock_client.stream = MagicMock(return_value=_make_stream_context(response))

        provider = GeminiProvider(api_key="test-key")
        provider._client = mock_client

        with pytest.raises(LLMProviderError) as exc_info:
            async for _ in provider.chat_stream(_hi_request()):
                pass

        assert "429" in str(exc_info.value)


class TestGeminiErrors:
    """Error handling tests."""

    @staticmethod
    def _provider_returning_error(
        status_code: int, message: str, api_key: str = "test-key"
    ) -> GeminiProvider:
        """Return a provider whose mock client answers POSTs with a Gemini error body."""
        mock_response = _make_mock_response(status_code, {
            "error": {
                "code": status_code,
                "message": message,
            },
        })
        provider = GeminiProvider(api_key=api_key)
        provider._client = _make_mock_client(mock_response)
        return provider

    async def test_400_bad_request(self):
        """A 400 error should raise LLMProviderError naming the provider and status."""
        provider = self._provider_returning_error(400, "Invalid request")

        with pytest.raises(LLMProviderError) as exc_info:
            await provider.chat(_hi_request())

        assert "gemini" in str(exc_info.value)
        assert "400" in str(exc_info.value)

    async def test_403_api_key_invalid(self):
        """A 403 error should raise LLMProviderError."""
        provider = self._provider_returning_error(403, "API key not valid", api_key="bad-key")

        with pytest.raises(LLMProviderError) as exc_info:
            await provider.chat(_hi_request())

        assert "403" in str(exc_info.value)

    async def test_429_rate_limit(self):
        """A 429 error should raise LLMProviderError."""
        provider = self._provider_returning_error(429, "Rate limit exceeded")

        with pytest.raises(LLMProviderError) as exc_info:
            await provider.chat(_hi_request())

        assert "429" in str(exc_info.value)

    async def test_500_server_error(self):
        """A 500 error should raise LLMProviderError."""
        provider = self._provider_returning_error(500, "Internal server error")

        with pytest.raises(LLMProviderError):
            await provider.chat(_hi_request())

    async def test_503_service_unavailable(self):
        """A 503 error should raise LLMProviderError."""
        provider = self._provider_returning_error(503, "Service unavailable")

        with pytest.raises(LLMProviderError) as exc_info:
            await provider.chat(_hi_request())

        assert "503" in str(exc_info.value)

    async def test_network_error(self):
        """A network error should be wrapped in LLMProviderError."""
        mock_client = MagicMock(spec=httpx.AsyncClient)
        mock_client.post = AsyncMock(side_effect=httpx.ConnectError("Connection refused"))

        provider = GeminiProvider(api_key="test-key")
        provider._client = mock_client

        with pytest.raises(LLMProviderError):
            await provider.chat(_hi_request())

    async def test_error_does_not_expose_api_key(self):
        """Error messages must not expose the API key."""
        provider = self._provider_returning_error(
            403, "API key not valid", api_key="my-super-secret-key-12345"
        )

        with pytest.raises(LLMProviderError) as exc_info:
            await provider.chat(_hi_request())

        assert "my-super-secret-key-12345" not in str(exc_info.value)


class TestGeminiGetModelInfo:
    """Tests for get_model_info()."""

    def test_returns_provider_and_model_info(self):
        provider = GeminiProvider(
            api_key="test-key",
            model="gemini-2.0-flash",
            max_output_tokens=8192,
        )
        info = provider.get_model_info()
        assert info["provider"] == "gemini"
        assert info["model"] == "gemini-2.0-flash"
        assert info["max_output_tokens"] == 8192

    def test_default_model_info(self):
        provider = GeminiProvider(api_key="test-key")
        info = provider.get_model_info()
        assert info["provider"] == "gemini"
        assert info["model"] == "gemini-2.0-flash"
        assert info["max_output_tokens"] == 4096


class TestGeminiLazyClient:
    """Lazy HTTP client initialization tests."""

    def test_client_not_created_on_init(self):
        """No HTTP client should be created at construction time."""
        provider = GeminiProvider(api_key="test-key")
        assert provider._client is None

    def test_client_created_on_first_use(self):
        """The HTTP client should be created on first use."""
        provider = GeminiProvider(api_key="test-key")
        client = provider._get_client()
        assert client is not None
        assert provider._client is not None

    def test_client_reused(self):
        """Repeated calls should reuse the same client instance."""
        provider = GeminiProvider(api_key="test-key")
        client1 = provider._get_client()
        client2 = provider._get_client()
        assert client1 is client2

    async def test_close_resets_client(self):
        """close() should reset the client to None."""
        provider = GeminiProvider(api_key="test-key")
        _ = provider._get_client()
        assert provider._client is not None
        await provider.close()
        assert provider._client is None