955 lines
33 KiB
Python
955 lines
33 KiB
Python
"""Gemini Provider 测试"""
|
|
|
|
import json
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import httpx
|
|
import pytest
|
|
from pytest_httpx import HTTPXMock
|
|
|
|
from agentkit.core.exceptions import LLMProviderError
|
|
from agentkit.llm.protocol import LLMRequest, LLMResponse, StreamChunk, TokenUsage
|
|
from agentkit.llm.providers.gemini import GeminiProvider
|
|
|
|
# Base URL of the Gemini API, deliberately without the ``key`` query parameter
# (per the original note: pytest-httpx matches URLs without the query string).
# NOTE(review): no test visible in this file references this constant — confirm
# whether it is still needed before relying on it.
_GEMINI_BASE = "https://generativelanguage.googleapis.com"
|
|
|
|
|
|
class TestGeminiMessageConversion:
    """Tests for ``GeminiProvider._convert_messages``.

    Each test passes OpenAI-style message dicts in and inspects the returned
    ``(system_instruction, contents)`` pair.
    """

    def setup_method(self):
        # pytest runs this before every test method, so each test gets its own provider.
        self.provider = GeminiProvider(api_key="test-key")

    def test_system_message_extracted_as_system_instruction(self):
        """A system message is extracted as the system instruction."""
        system_instruction, contents = self.provider._convert_messages(
            [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello"},
            ]
        )

        assert system_instruction == {"parts": [{"text": "You are a helpful assistant."}]}
        assert len(contents) == 1
        remaining = contents[0]
        assert remaining["role"] == "user"
        assert remaining["parts"] == [{"text": "Hello"}]

    def test_text_messages_converted_to_contents(self):
        """Plain text messages become Gemini ``contents`` entries."""
        conversation = [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello!"},
            {"role": "user", "content": "How are you?"},
        ]
        system_instruction, contents = self.provider._convert_messages(conversation)

        assert system_instruction is None
        assert len(contents) == 3
        # The assistant role is expected to be renamed to "model".
        expected_turns = [
            {"role": "user", "parts": [{"text": "Hi"}]},
            {"role": "model", "parts": [{"text": "Hello!"}]},
            {"role": "user", "parts": [{"text": "How are you?"}]},
        ]
        for index, wanted in enumerate(expected_turns):
            assert contents[index] == wanted

    def test_assistant_tool_calls_converted(self):
        """Assistant ``tool_calls`` are converted to ``functionCall`` parts."""
        weather_call = {
            "id": "call_123",
            "type": "function",
            "function": {
                "name": "get_weather",
                "arguments": '{"city": "Beijing"}',
            },
        }
        _, contents = self.provider._convert_messages(
            [
                {"role": "user", "content": "What's the weather?"},
                {"role": "assistant", "content": None, "tool_calls": [weather_call]},
            ]
        )

        assert len(contents) == 2
        model_turn = contents[1]
        assert model_turn["role"] == "model"
        turn_parts = model_turn["parts"]
        assert len(turn_parts) == 1
        assert "functionCall" in turn_parts[0]
        function_call = turn_parts[0]["functionCall"]
        assert function_call["name"] == "get_weather"
        # The JSON string of arguments is expected back as a decoded dict.
        assert function_call["args"] == {"city": "Beijing"}

    def test_assistant_tool_calls_with_text(self):
        """An assistant message may carry both text and ``tool_calls``."""
        search_call = {
            "id": "call_456",
            "type": "function",
            "function": {
                "name": "search",
                "arguments": '{"q": "test"}',
            },
        }
        assistant_turn = {
            "role": "assistant",
            "content": "Let me check that.",
            "tool_calls": [search_call],
        }
        _, contents = self.provider._convert_messages([assistant_turn])

        text_and_call = contents[0]["parts"]
        assert len(text_and_call) == 2
        assert text_and_call[0] == {"text": "Let me check that."}
        assert "functionCall" in text_and_call[1]

    def test_tool_result_converted_to_function_response(self):
        """A ``tool`` role message is converted to a ``functionResponse`` part."""
        tool_result = {
            "role": "tool",
            "tool_call_id": "call_123",
            "name": "get_weather",
            "content": "Sunny, 25°C",
        }
        _, contents = self.provider._convert_messages([tool_result])

        assert len(contents) == 1
        converted = contents[0]
        assert converted["role"] == "user"
        assert len(converted["parts"]) == 1
        first_part = converted["parts"][0]
        assert "functionResponse" in first_part
        function_response = first_part["functionResponse"]
        assert function_response["name"] == "get_weather"
        assert function_response["response"]["content"] == "Sunny, 25°C"

    def test_user_with_tool_call_id_converted(self):
        """A user message with ``tool_call_id`` also becomes a ``functionResponse``."""
        user_result = {
            "role": "user",
            "tool_call_id": "call_789",
            "content": "Result data",
        }
        _, contents = self.provider._convert_messages([user_result])

        converted = contents[0]
        assert converted["role"] == "user"
        assert "functionResponse" in converted["parts"][0]

    def test_no_system_message(self):
        """Without a system message the system instruction is ``None``."""
        system_instruction, _ = self.provider._convert_messages(
            [{"role": "user", "content": "Hello"}]
        )
        assert system_instruction is None
|
|
|
|
|
|
class TestGeminiToolConversion:
    """Tests for ``_convert_tools`` and ``_convert_tool_choice``."""

    def setup_method(self):
        # pytest runs this before every test method, so each test gets its own provider.
        self.provider = GeminiProvider(api_key="test-key")

    @staticmethod
    def _calling_config(mode):
        """Return the expected ``functionCallingConfig`` wrapper for *mode*."""
        return {"functionCallingConfig": {"mode": mode}}

    def test_convert_openai_tools_to_gemini(self):
        """OpenAI function tools are converted to Gemini ``functionDeclarations``."""
        openai_tool = {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                },
            },
        }
        converted = self.provider._convert_tools([openai_tool])

        assert len(converted) == 1
        assert "functionDeclarations" in converted[0]
        declarations = converted[0]["functionDeclarations"]
        assert len(declarations) == 1
        declaration = declarations[0]
        assert declaration["name"] == "get_weather"
        assert declaration["description"] == "Get weather for a city"
        assert declaration["parameters"] == {
            "type": "object",
            "properties": {"city": {"type": "string"}},
        }

    def test_convert_empty_tools(self):
        """An empty tool list converts to an empty list."""
        converted = self.provider._convert_tools([])
        assert converted == []

    def test_convert_tool_choice_auto(self):
        """``tool_choice="auto"`` maps to Gemini mode ``AUTO``."""
        config = self.provider._convert_tool_choice("auto")
        assert config == self._calling_config("AUTO")

    def test_convert_tool_choice_required(self):
        """``tool_choice="required"`` maps to Gemini mode ``ANY``."""
        config = self.provider._convert_tool_choice("required")
        assert config == self._calling_config("ANY")

    def test_convert_tool_choice_none(self):
        """``tool_choice="none"`` maps to Gemini mode ``NONE``."""
        config = self.provider._convert_tool_choice("none")
        assert config == self._calling_config("NONE")

    def test_convert_tool_choice_specific_tool(self):
        """A specific tool name as ``tool_choice`` maps to Gemini mode ``AUTO``."""
        config = self.provider._convert_tool_choice("get_weather")
        assert config == self._calling_config("AUTO")
|
|
|
|
|
|
class TestGeminiResponseParsing:
    """Tests for ``GeminiProvider._parse_response`` on raw response dicts."""

    # Model name passed as the second argument to ``_parse_response`` in every test.
    _REQUESTED_MODEL = "gemini-2.0-flash"

    def setup_method(self):
        # pytest runs this before every test method, so each test gets its own provider.
        self.provider = GeminiProvider(api_key="test-key")

    @staticmethod
    def _candidate(*parts):
        """Wrap *parts* in a single model candidate that finished with ``STOP``."""
        return {
            "content": {
                "parts": list(parts),
                "role": "model",
            },
            "finishReason": "STOP",
        }

    @staticmethod
    def _weather_call(city):
        """Return a ``functionCall`` part invoking ``get_weather`` for *city*."""
        return {
            "functionCall": {
                "name": "get_weather",
                "args": {"city": city},
            }
        }

    def _parse(self, data):
        """Run *data* through the provider's response parser."""
        return self.provider._parse_response(data, self._REQUESTED_MODEL)

    def test_parse_text_response(self):
        """A plain text response is parsed into content and usage."""
        parsed = self._parse(
            {
                "candidates": [self._candidate({"text": "Hello! How can I help?"})],
                "usageMetadata": {
                    "promptTokenCount": 10,
                    "candidatesTokenCount": 6,
                    "totalTokenCount": 16,
                },
            }
        )

        assert isinstance(parsed, LLMResponse)
        assert parsed.content == "Hello! How can I help?"
        assert parsed.usage.prompt_tokens == 10
        assert parsed.usage.completion_tokens == 6
        assert not parsed.has_tool_calls

    def test_parse_function_call_response(self):
        """A response mixing text and a ``functionCall`` yields both."""
        parsed = self._parse(
            {
                "candidates": [
                    self._candidate(
                        {"text": "Let me check the weather."},
                        self._weather_call("Beijing"),
                    )
                ],
                "usageMetadata": {
                    "promptTokenCount": 20,
                    "candidatesTokenCount": 15,
                    "totalTokenCount": 35,
                },
            }
        )

        assert parsed.content == "Let me check the weather."
        assert parsed.has_tool_calls
        assert len(parsed.tool_calls) == 1
        only_call = parsed.tool_calls[0]
        assert only_call.name == "get_weather"
        assert only_call.arguments == {"city": "Beijing"}

    def test_parse_multiple_function_calls(self):
        """Several ``functionCall`` parts are all returned, in order."""
        parsed = self._parse(
            {
                "candidates": [
                    self._candidate(
                        self._weather_call("Beijing"),
                        self._weather_call("Shanghai"),
                    )
                ],
                "usageMetadata": {
                    "promptTokenCount": 25,
                    "candidatesTokenCount": 20,
                    "totalTokenCount": 45,
                },
            }
        )

        assert len(parsed.tool_calls) == 2
        first_call, second_call = parsed.tool_calls[0], parsed.tool_calls[1]
        assert first_call.name == "get_weather"
        assert first_call.arguments == {"city": "Beijing"}
        assert second_call.arguments == {"city": "Shanghai"}

    def test_parse_empty_candidates(self):
        """An empty ``candidates`` list gives empty content and no tool calls."""
        parsed = self._parse(
            {
                "candidates": [],
                "usageMetadata": {
                    "promptTokenCount": 5,
                    "candidatesTokenCount": 0,
                },
            }
        )

        assert parsed.content == ""
        assert not parsed.has_tool_calls

    def test_parse_model_version_in_response(self):
        """``modelVersion`` from the response is reported as the model."""
        parsed = self._parse(
            {
                "candidates": [self._candidate({"text": "Hi"})],
                "modelVersion": "gemini-2.0-flash-001",
                "usageMetadata": {
                    "promptTokenCount": 5,
                    "candidatesTokenCount": 2,
                },
            }
        )
        assert parsed.model == "gemini-2.0-flash-001"
|
|
|
|
|
|
class TestGeminiChat:
    """Integration-style tests for ``chat()`` using a mock client, not ``httpx_mock``.

    Each test assigns a ``MagicMock`` to the provider's ``_client`` attribute and
    then inspects both the parsed response and the arguments passed to ``post``.
    """

    _MODEL = "gemini-2.0-flash"

    def _make_mock_response(self, status_code: int, json_data: dict):
        """Create a mock httpx response carrying *status_code* and *json_data*."""
        fake = MagicMock(spec=httpx.Response)
        fake.status_code = status_code
        fake.json = MagicMock(return_value=json_data)
        fake.content = json.dumps(json_data).encode()
        return fake

    def _ok_response(self, parts, usage):
        """Build a 200 mock response with one model candidate made of *parts*."""
        body = {
            "candidates": [
                {
                    "content": {
                        "parts": parts,
                        "role": "model",
                    },
                    "finishReason": "STOP",
                }
            ],
            "usageMetadata": usage,
        }
        return self._make_mock_response(200, body)

    @staticmethod
    def _attach_client(provider, response):
        """Give *provider* a mock async client whose ``post`` returns *response*."""
        client = MagicMock(spec=httpx.AsyncClient)
        client.post = AsyncMock(return_value=response)
        provider._client = client
        return client

    @staticmethod
    def _posted_json(client):
        """Return the ``json=`` keyword of the recorded ``post`` call (``{}`` if absent)."""
        return client.post.call_args.kwargs.get("json", {})

    @staticmethod
    def _posted_url(client):
        """Return the URL of the recorded ``post`` call, positional or keyword."""
        recorded = client.post.call_args
        positional = recorded[0]
        if positional:
            return positional[0]
        return recorded.kwargs.get("url", "")

    async def test_chat_returns_llm_response(self):
        """``chat`` returns an ``LLMResponse``."""
        reply = self._ok_response(
            [{"text": "Hello from Gemini!"}],
            {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15,
            },
        )
        provider = GeminiProvider(api_key="test-key")
        self._attach_client(provider, reply)

        result = await provider.chat(
            LLMRequest(
                messages=[{"role": "user", "content": "Hi"}],
                model=self._MODEL,
            )
        )

        assert isinstance(result, LLMResponse)
        assert result.content == "Hello from Gemini!"
        assert result.usage.prompt_tokens == 10
        assert result.usage.completion_tokens == 5
        assert result.latency_ms > 0

    async def test_chat_with_system_message(self):
        """The system message is sent as ``systemInstruction``."""
        reply = self._ok_response(
            [{"text": "I am a helpful assistant."}],
            {
                "promptTokenCount": 15,
                "candidatesTokenCount": 8,
            },
        )
        provider = GeminiProvider(api_key="test-key")
        client = self._attach_client(provider, reply)

        result = await provider.chat(
            LLMRequest(
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": "Who are you?"},
                ],
                model=self._MODEL,
            )
        )

        assert result.content == "I am a helpful assistant."

        # Verify the request payload.
        sent = self._posted_json(client)
        assert "systemInstruction" in sent
        assert sent["systemInstruction"]["parts"][0]["text"] == "You are a helpful assistant."
        # The system message must NOT also appear in contents.
        for entry in sent["contents"]:
            assert entry["role"] != "system"

    async def test_chat_with_tools(self):
        """A request with tools is converted to the Gemini tool format."""
        reply = self._ok_response(
            [
                {
                    "functionCall": {
                        "name": "get_weather",
                        "args": {"city": "Tokyo"},
                    }
                }
            ],
            {
                "promptTokenCount": 30,
                "candidatesTokenCount": 20,
            },
        )
        provider = GeminiProvider(api_key="test-key")
        client = self._attach_client(provider, reply)

        weather_tool = {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get weather",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                },
            },
        }
        result = await provider.chat(
            LLMRequest(
                messages=[{"role": "user", "content": "Weather in Tokyo?"}],
                model=self._MODEL,
                tools=[weather_tool],
            )
        )

        assert result.has_tool_calls
        assert result.tool_calls[0].name == "get_weather"
        assert result.tool_calls[0].arguments == {"city": "Tokyo"}

        # Verify the request format.
        sent = self._posted_json(client)
        assert "tools" in sent
        assert "functionDeclarations" in sent["tools"][0]
        assert sent["tools"][0]["functionDeclarations"][0]["name"] == "get_weather"
        assert "toolConfig" in sent
        assert sent["toolConfig"]["functionCallingConfig"]["mode"] == "AUTO"

    async def test_chat_api_key_in_url(self):
        """The API key is passed as a URL parameter."""
        reply = self._ok_response(
            [{"text": "OK"}],
            {
                "promptTokenCount": 5,
                "candidatesTokenCount": 2,
            },
        )
        provider = GeminiProvider(api_key="my-secret-key")
        client = self._attach_client(provider, reply)

        await provider.chat(
            LLMRequest(
                messages=[{"role": "user", "content": "Hi"}],
                model=self._MODEL,
            )
        )

        assert "key=my-secret-key" in self._posted_url(client)

    async def test_chat_with_custom_base_url(self):
        """A custom ``base_url`` is used for the request."""
        reply = self._ok_response(
            [{"text": "Proxy response"}],
            {
                "promptTokenCount": 5,
                "candidatesTokenCount": 3,
            },
        )
        provider = GeminiProvider(
            api_key="test-key",
            base_url="https://custom-proxy.example.com",
        )
        client = self._attach_client(provider, reply)

        result = await provider.chat(
            LLMRequest(
                messages=[{"role": "user", "content": "Hi"}],
                model=self._MODEL,
            )
        )

        assert result.content == "Proxy response"
        assert "custom-proxy.example.com" in self._posted_url(client)
|
|
|
|
|
|
class TestGeminiStreaming:
    """Tests for ``chat_stream()`` driven by a mocked streaming client."""

    _MODEL = "gemini-2.0-flash"

    def _make_stream_response(self, sse_lines: list[str]):
        """Create a mock httpx streaming response context manager.

        Entering the returned mock yields a response mock with status 200 whose
        ``aiter_lines()`` replays *sse_lines* in order.
        """
        async def replay_lines():
            for sse_line in sse_lines:
                yield sse_line

        fake_response = MagicMock(status_code=200)
        fake_response.aiter_lines = replay_lines
        fake_response.aread = AsyncMock(return_value=b"")

        stream_cm = MagicMock()
        stream_cm.__aenter__ = AsyncMock(return_value=fake_response)
        stream_cm.__aexit__ = AsyncMock(return_value=False)
        return stream_cm

    @staticmethod
    def _provider_streaming(stream_cm):
        """Return a provider whose mock client's ``stream()`` returns *stream_cm*."""
        client = MagicMock()
        client.stream = MagicMock(return_value=stream_cm)
        provider = GeminiProvider(api_key="test-key")
        provider._client = client
        return provider

    @staticmethod
    async def _drain(provider, request):
        """Collect every chunk produced by ``chat_stream`` into a list."""
        return [chunk async for chunk in provider.chat_stream(request)]

    async def test_stream_text_response(self):
        """Streamed text chunks are parsed in order."""
        provider = self._provider_streaming(
            self._make_stream_response(
                [
                    'data: {"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":3,"totalTokenCount":8}}',
                    '',
                    'data: {"candidates":[{"content":{"parts":[{"text":" world"}],"role":"model"},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":5,"totalTokenCount":10}}',
                    '',
                ]
            )
        )
        request = LLMRequest(
            messages=[{"role": "user", "content": "Hi"}],
            model=self._MODEL,
        )

        received = await self._drain(provider, request)

        # Only chunks that actually carry text are of interest here.
        with_text = [piece for piece in received if piece.content]
        assert len(with_text) == 2
        assert with_text[0].content == "Hello"
        assert with_text[1].content == " world"

    async def test_stream_function_call_response(self):
        """A streamed ``functionCall`` is surfaced on the final chunk."""
        provider = self._provider_streaming(
            self._make_stream_response(
                [
                    'data: {"candidates":[{"content":{"parts":[{"functionCall":{"name":"get_weather","args":{"city":"Paris"}}}],"role":"model"},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":20,"candidatesTokenCount":15}}',
                    '',
                ]
            )
        )
        weather_tool = {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get weather",
                "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
            },
        }
        request = LLMRequest(
            messages=[{"role": "user", "content": "Weather in Paris?"}],
            model=self._MODEL,
            tools=[weather_tool],
        )

        received = await self._drain(provider, request)

        finals = [piece for piece in received if piece.is_final]
        assert len(finals) == 1
        last = finals[0]
        assert len(last.tool_calls) == 1
        assert last.tool_calls[0].name == "get_weather"
        assert last.tool_calls[0].arguments == {"city": "Paris"}

    async def test_stream_with_usage_metadata(self):
        """The final streamed chunk carries usage information."""
        provider = self._provider_streaming(
            self._make_stream_response(
                [
                    'data: {"candidates":[{"content":{"parts":[{"text":"Hi"}],"role":"model"},"finishReason":"STOP"}]}',
                    '',
                    'data: {"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15}}',
                    '',
                ]
            )
        )
        request = LLMRequest(
            messages=[{"role": "user", "content": "Hi"}],
            model=self._MODEL,
        )

        received = await self._drain(provider, request)

        finals = [piece for piece in received if piece.is_final]
        assert len(finals) == 1
        last = finals[0]
        assert last.usage is not None
        assert last.usage.prompt_tokens == 10
        assert last.usage.completion_tokens == 5

    async def test_stream_non_200_status(self):
        """A non-200 streaming status raises ``LLMProviderError``."""
        rate_limited = MagicMock(status_code=429)
        rate_limited.aread = AsyncMock(return_value=b'{"error":{"code":429,"message":"Rate limit exceeded"}}')

        stream_cm = MagicMock()
        stream_cm.__aenter__ = AsyncMock(return_value=rate_limited)
        stream_cm.__aexit__ = AsyncMock(return_value=False)

        provider = self._provider_streaming(stream_cm)
        request = LLMRequest(
            messages=[{"role": "user", "content": "Hi"}],
            model=self._MODEL,
        )

        with pytest.raises(LLMProviderError) as exc_info:
            async for _ in provider.chat_stream(request):
                pass

        assert "429" in str(exc_info.value)
|
|
|
|
|
|
class TestGeminiErrors:
    """Error-handling tests: failing HTTP calls must surface as ``LLMProviderError``."""

    def _make_mock_response(self, status_code: int, json_data: dict):
        """Create a mock httpx response carrying *status_code* and *json_data*."""
        fake = MagicMock(spec=httpx.Response)
        fake.status_code = status_code
        fake.json = MagicMock(return_value=json_data)
        fake.content = json.dumps(json_data).encode()
        return fake

    def _provider_answering(self, status_code, message, api_key="test-key"):
        """Return a provider whose mock client answers ``post`` with an error body."""
        error_reply = self._make_mock_response(status_code, {
            "error": {
                "code": status_code,
                "message": message,
            },
        })
        client = MagicMock(spec=httpx.AsyncClient)
        client.post = AsyncMock(return_value=error_reply)

        provider = GeminiProvider(api_key=api_key)
        provider._client = client
        return provider

    @staticmethod
    async def _chat_failure(provider):
        """Call ``chat`` with a minimal request and return the raised error."""
        request = LLMRequest(
            messages=[{"role": "user", "content": "Hi"}],
            model="gemini-2.0-flash",
        )
        with pytest.raises(LLMProviderError) as exc_info:
            await provider.chat(request)
        return exc_info.value

    async def test_400_bad_request(self):
        """A 400 response raises ``LLMProviderError``."""
        provider = self._provider_answering(400, "Invalid request")

        error = await self._chat_failure(provider)

        assert "gemini" in str(error)
        assert "400" in str(error)

    async def test_403_api_key_invalid(self):
        """A 403 response raises ``LLMProviderError``."""
        provider = self._provider_answering(403, "API key not valid", api_key="bad-key")

        error = await self._chat_failure(provider)

        assert "403" in str(error)

    async def test_429_rate_limit(self):
        """A 429 response raises ``LLMProviderError``."""
        provider = self._provider_answering(429, "Rate limit exceeded")

        error = await self._chat_failure(provider)

        assert "429" in str(error)

    async def test_500_server_error(self):
        """A 500 response raises ``LLMProviderError``."""
        provider = self._provider_answering(500, "Internal server error")

        # Only the exception type is checked for this status.
        await self._chat_failure(provider)

    async def test_503_service_unavailable(self):
        """A 503 response raises ``LLMProviderError``."""
        provider = self._provider_answering(503, "Service unavailable")

        error = await self._chat_failure(provider)

        assert "503" in str(error)

    async def test_network_error(self):
        """A network-level failure raises ``LLMProviderError``."""
        client = MagicMock(spec=httpx.AsyncClient)
        client.post = AsyncMock(side_effect=httpx.ConnectError("Connection refused"))

        provider = GeminiProvider(api_key="test-key")
        provider._client = client

        await self._chat_failure(provider)

    async def test_error_does_not_expose_api_key(self):
        """The error message must not expose the API key."""
        provider = self._provider_answering(
            403,
            "API key not valid",
            api_key="my-super-secret-key-12345",
        )

        error = await self._chat_failure(provider)

        assert "my-super-secret-key-12345" not in str(error)
|
|
|
|
|
|
class TestGeminiGetModelInfo:
    """Tests for ``get_model_info()``."""

    @staticmethod
    def _assert_info(info, model, max_output_tokens):
        """Check the provider, model and output-token entries of *info*."""
        assert info["provider"] == "gemini"
        assert info["model"] == model
        assert info["max_output_tokens"] == max_output_tokens

    def test_returns_provider_and_model_info(self):
        """Explicit constructor arguments are reported back."""
        explicit = GeminiProvider(
            api_key="test-key",
            model="gemini-2.0-flash",
            max_output_tokens=8192,
        )
        self._assert_info(explicit.get_model_info(), "gemini-2.0-flash", 8192)

    def test_default_model_info(self):
        """With only an API key, the expected defaults are reported."""
        defaulted = GeminiProvider(api_key="test-key")
        self._assert_info(defaulted.get_model_info(), "gemini-2.0-flash", 4096)
|
|
|
|
|
|
class TestGeminiLazyClient:
    """Tests for lazy creation and reset of the provider's HTTP client."""

    @staticmethod
    def _new_provider():
        """Return a provider built with only a placeholder API key."""
        return GeminiProvider(api_key="test-key")

    def test_client_not_created_on_init(self):
        """No HTTP client exists right after construction."""
        assert self._new_provider()._client is None

    def test_client_created_on_first_use(self):
        """The HTTP client is created on first use."""
        provider = self._new_provider()
        created = provider._get_client()
        assert created is not None
        assert provider._client is not None

    def test_client_reused(self):
        """Repeated calls return the same client object."""
        provider = self._new_provider()
        first = provider._get_client()
        second = provider._get_client()
        assert first is second

    async def test_close_resets_client(self):
        """After ``close()`` the client is reset to ``None``."""
        provider = self._new_provider()
        provider._get_client()
        assert provider._client is not None

        await provider.close()
        assert provider._client is None
|