# fischer-agentkit/tests/unit/test_gemini_provider.py

"""Unit tests for the Gemini LLM provider (GeminiProvider)."""
import json
from unittest.mock import AsyncMock, MagicMock
import httpx
import pytest
from pytest_httpx import HTTPXMock
from agentkit.core.exceptions import LLMProviderError
from agentkit.llm.protocol import LLMRequest, LLMResponse, StreamChunk, TokenUsage
from agentkit.llm.providers.gemini import GeminiProvider
# Public Gemini API host, without the ``?key=...`` query parameter, intended
# for pytest-httpx URL matching. NOTE(review): no test visible in this file
# references this constant (the tests inject mock clients instead of using
# pytest-httpx); confirm it is still needed before relying on it.
_GEMINI_BASE = "https://generativelanguage.googleapis.com"
class TestGeminiMessageConversion:
    """Tests for ``GeminiProvider._convert_messages`` (OpenAI messages -> Gemini)."""

    def setup_method(self):
        self.provider = GeminiProvider(api_key="test-key")

    @staticmethod
    def _tool_call(call_id, name, arguments):
        """Build one OpenAI-style ``tool_calls`` entry with JSON-string arguments."""
        return {
            "id": call_id,
            "type": "function",
            "function": {"name": name, "arguments": arguments},
        }

    def test_system_message_extracted_as_system_instruction(self):
        """A system message is returned separately as the ``systemInstruction``."""
        history = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello"},
        ]
        instruction, converted = self.provider._convert_messages(history)
        assert instruction == {"parts": [{"text": "You are a helpful assistant."}]}
        assert len(converted) == 1
        user_turn = converted[0]
        assert user_turn["role"] == "user"
        assert user_turn["parts"] == [{"text": "Hello"}]

    def test_text_messages_converted_to_contents(self):
        """Plain text turns map to Gemini contents; ``assistant`` becomes ``model``."""
        history = [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello!"},
            {"role": "user", "content": "How are you?"},
        ]
        expected = [
            {"role": "user", "parts": [{"text": "Hi"}]},
            {"role": "model", "parts": [{"text": "Hello!"}]},
            {"role": "user", "parts": [{"text": "How are you?"}]},
        ]
        instruction, converted = self.provider._convert_messages(history)
        assert instruction is None
        assert len(converted) == len(expected)
        for got, want in zip(converted, expected):
            assert got == want

    def test_assistant_tool_calls_converted(self):
        """Assistant ``tool_calls`` become ``functionCall`` parts with parsed args."""
        history = [
            {"role": "user", "content": "What's the weather?"},
            {
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    self._tool_call("call_123", "get_weather", '{"city": "Beijing"}')
                ],
            },
        ]
        _, converted = self.provider._convert_messages(history)
        assert len(converted) == 2
        model_turn = converted[1]
        assert model_turn["role"] == "model"
        assert len(model_turn["parts"]) == 1
        part = model_turn["parts"][0]
        assert "functionCall" in part
        call = part["functionCall"]
        assert call["name"] == "get_weather"
        assert call["args"] == {"city": "Beijing"}

    def test_assistant_tool_calls_with_text(self):
        """Assistant text is kept as a text part ahead of the ``functionCall`` part."""
        history = [
            {
                "role": "assistant",
                "content": "Let me check that.",
                "tool_calls": [self._tool_call("call_456", "search", '{"q": "test"}')],
            },
        ]
        _, converted = self.provider._convert_messages(history)
        parts = converted[0]["parts"]
        assert len(parts) == 2
        text_part, call_part = parts
        assert text_part == {"text": "Let me check that."}
        assert "functionCall" in call_part

    def test_tool_result_converted_to_function_response(self):
        """A ``tool`` message becomes a user-role ``functionResponse`` part."""
        history = [
            {
                "role": "tool",
                "tool_call_id": "call_123",
                "name": "get_weather",
                "content": "Sunny, 25°C",
            },
        ]
        _, converted = self.provider._convert_messages(history)
        assert len(converted) == 1
        tool_turn = converted[0]
        assert tool_turn["role"] == "user"
        assert len(tool_turn["parts"]) == 1
        part = tool_turn["parts"][0]
        assert "functionResponse" in part
        reply = part["functionResponse"]
        assert reply["name"] == "get_weather"
        assert reply["response"]["content"] == "Sunny, 25°C"

    def test_user_with_tool_call_id_converted(self):
        """A user message carrying ``tool_call_id`` is also a ``functionResponse``."""
        history = [
            {
                "role": "user",
                "tool_call_id": "call_789",
                "content": "Result data",
            },
        ]
        _, converted = self.provider._convert_messages(history)
        user_turn = converted[0]
        assert user_turn["role"] == "user"
        assert "functionResponse" in user_turn["parts"][0]

    def test_no_system_message(self):
        """Without a system message the system instruction is ``None``."""
        instruction, _ = self.provider._convert_messages(
            [{"role": "user", "content": "Hello"}]
        )
        assert instruction is None
class TestGeminiToolConversion:
    """Tests for tool-definition and ``tool_choice`` conversion to Gemini format."""

    def setup_method(self):
        self.provider = GeminiProvider(api_key="test-key")

    @staticmethod
    def _calling_mode(mode):
        """Expected ``toolConfig`` payload for a Gemini function-calling *mode*."""
        return {"functionCallingConfig": {"mode": mode}}

    def test_convert_openai_tools_to_gemini(self):
        """An OpenAI ``function`` tool becomes one Gemini ``functionDeclarations`` entry."""

        def weather_params():
            # Fresh dict on every call so input and expectation never share state.
            return {
                "type": "object",
                "properties": {"city": {"type": "string"}},
            }

        openai_tools = [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get weather for a city",
                    "parameters": weather_params(),
                },
            }
        ]
        converted = self.provider._convert_tools(openai_tools)
        assert len(converted) == 1
        group = converted[0]
        assert "functionDeclarations" in group
        decls = group["functionDeclarations"]
        assert len(decls) == 1
        decl = decls[0]
        assert decl["name"] == "get_weather"
        assert decl["description"] == "Get weather for a city"
        assert decl["parameters"] == weather_params()

    def test_convert_empty_tools(self):
        """An empty tool list converts to an empty list."""
        assert self.provider._convert_tools([]) == []

    def test_convert_tool_choice_auto(self):
        """``tool_choice="auto"`` maps to Gemini AUTO mode."""
        assert self.provider._convert_tool_choice("auto") == self._calling_mode("AUTO")

    def test_convert_tool_choice_required(self):
        """``tool_choice="required"`` maps to Gemini ANY mode."""
        assert self.provider._convert_tool_choice("required") == self._calling_mode("ANY")

    def test_convert_tool_choice_none(self):
        """``tool_choice="none"`` maps to Gemini NONE mode."""
        assert self.provider._convert_tool_choice("none") == self._calling_mode("NONE")

    def test_convert_tool_choice_specific_tool(self):
        """A specific tool name as ``tool_choice`` maps to Gemini AUTO mode."""
        assert self.provider._convert_tool_choice("get_weather") == self._calling_mode("AUTO")
class TestGeminiResponseParsing:
    """Tests for ``GeminiProvider._parse_response`` (Gemini JSON -> LLMResponse)."""

    def setup_method(self):
        self.provider = GeminiProvider(api_key="test-key")

    @staticmethod
    def _reply(parts, usage):
        """Build a response body with one STOP candidate made of *parts*."""
        return {
            "candidates": [
                {
                    "content": {"parts": parts, "role": "model"},
                    "finishReason": "STOP",
                }
            ],
            "usageMetadata": usage,
        }

    def _parse(self, payload):
        """Parse *payload* as a response to a ``gemini-2.0-flash`` request."""
        return self.provider._parse_response(payload, "gemini-2.0-flash")

    def test_parse_text_response(self):
        """A text-only candidate yields content and token usage, no tool calls."""
        parsed = self._parse(
            self._reply(
                [{"text": "Hello! How can I help?"}],
                {"promptTokenCount": 10, "candidatesTokenCount": 6, "totalTokenCount": 16},
            )
        )
        assert isinstance(parsed, LLMResponse)
        assert parsed.content == "Hello! How can I help?"
        assert parsed.usage.prompt_tokens == 10
        assert parsed.usage.completion_tokens == 6
        assert not parsed.has_tool_calls

    def test_parse_function_call_response(self):
        """Text and a ``functionCall`` in one candidate are both surfaced."""
        parsed = self._parse(
            self._reply(
                [
                    {"text": "Let me check the weather."},
                    {"functionCall": {"name": "get_weather", "args": {"city": "Beijing"}}},
                ],
                {"promptTokenCount": 20, "candidatesTokenCount": 15, "totalTokenCount": 35},
            )
        )
        assert parsed.content == "Let me check the weather."
        assert parsed.has_tool_calls
        assert len(parsed.tool_calls) == 1
        call = parsed.tool_calls[0]
        assert call.name == "get_weather"
        assert call.arguments == {"city": "Beijing"}

    def test_parse_multiple_function_calls(self):
        """Several ``functionCall`` parts become tool calls in the same order."""
        parsed = self._parse(
            self._reply(
                [
                    {"functionCall": {"name": "get_weather", "args": {"city": "Beijing"}}},
                    {"functionCall": {"name": "get_weather", "args": {"city": "Shanghai"}}},
                ],
                {"promptTokenCount": 25, "candidatesTokenCount": 20, "totalTokenCount": 45},
            )
        )
        assert len(parsed.tool_calls) == 2
        first, second = parsed.tool_calls
        assert first.name == "get_weather"
        assert first.arguments == {"city": "Beijing"}
        assert second.arguments == {"city": "Shanghai"}

    def test_parse_empty_candidates(self):
        """No candidates yields empty content and no tool calls."""
        parsed = self._parse(
            {
                "candidates": [],
                "usageMetadata": {"promptTokenCount": 5, "candidatesTokenCount": 0},
            }
        )
        assert parsed.content == ""
        assert not parsed.has_tool_calls

    def test_parse_model_version_in_response(self):
        """``modelVersion`` in the body takes precedence as the reported model."""
        payload = self._reply(
            [{"text": "Hi"}],
            {"promptTokenCount": 5, "candidatesTokenCount": 2},
        )
        payload["modelVersion"] = "gemini-2.0-flash-001"
        assert self._parse(payload).model == "gemini-2.0-flash-001"
class TestGeminiChat:
    """Tests for ``GeminiProvider.chat()``.

    A mocked ``httpx.AsyncClient`` is injected as the provider's ``_client``
    (instead of using ``httpx_mock``). The tests can then feed canned responses
    and inspect the ``post`` call the provider made.
    """

    def _make_mock_response(self, status_code: int, json_data: dict):
        """Create a mock ``httpx.Response`` with the given status and JSON body."""
        response = MagicMock(spec=httpx.Response)
        response.status_code = status_code
        response.json = MagicMock(return_value=json_data)
        response.content = json.dumps(json_data).encode()
        return response

    def _make_provider(self, json_data: dict, **provider_kwargs):
        """Build a provider whose client answers every ``post`` with 200 + *json_data*.

        Extra keyword arguments are forwarded to ``GeminiProvider``. ``api_key``
        defaults to ``"test-key"``.

        Returns:
            ``(provider, mock_client)``, so callers can inspect ``mock_client.post``.
        """
        mock_client = MagicMock(spec=httpx.AsyncClient)
        mock_client.post = AsyncMock(
            return_value=self._make_mock_response(200, json_data)
        )
        provider_kwargs.setdefault("api_key", "test-key")
        provider = GeminiProvider(**provider_kwargs)
        provider._client = mock_client
        return provider, mock_client

    @staticmethod
    def _model_reply(parts, usage):
        """Build a ``generateContent`` body with one STOP candidate and *usage*."""
        return {
            "candidates": [
                {
                    "content": {"parts": parts, "role": "model"},
                    "finishReason": "STOP",
                }
            ],
            "usageMetadata": usage,
        }

    @staticmethod
    def _request(messages, **extra):
        """Build an ``LLMRequest`` for ``gemini-2.0-flash`` with *messages*."""
        return LLMRequest(messages=messages, model="gemini-2.0-flash", **extra)

    @staticmethod
    def _posted_json(mock_client):
        """Return the JSON body passed to the last ``post`` call.

        ``httpx.AsyncClient.post`` accepts ``json`` only as a keyword argument,
        so it is read from ``call_args.kwargs``. Fail loudly if the body was
        sent some other way, rather than silently comparing against ``{}``.
        """
        kwargs = mock_client.post.call_args.kwargs
        assert "json" in kwargs, "expected the request body to be passed via json="
        return kwargs["json"]

    @staticmethod
    def _posted_url(mock_client):
        """Return the URL of the last ``post`` call (positional or ``url=``)."""
        call_args = mock_client.post.call_args
        if call_args.args:
            return call_args.args[0]
        return call_args.kwargs.get("url", "")

    async def test_chat_returns_llm_response(self):
        """chat() parses the reply into an LLMResponse with usage and latency."""
        provider, _ = self._make_provider(
            self._model_reply(
                [{"text": "Hello from Gemini!"}],
                {"promptTokenCount": 10, "candidatesTokenCount": 5, "totalTokenCount": 15},
            )
        )
        response = await provider.chat(self._request([{"role": "user", "content": "Hi"}]))
        assert isinstance(response, LLMResponse)
        assert response.content == "Hello from Gemini!"
        assert response.usage.prompt_tokens == 10
        assert response.usage.completion_tokens == 5
        assert response.latency_ms > 0

    async def test_chat_with_system_message(self):
        """A system message is sent as ``systemInstruction``, not as a content turn."""
        provider, mock_client = self._make_provider(
            self._model_reply(
                [{"text": "I am a helpful assistant."}],
                {"promptTokenCount": 15, "candidatesTokenCount": 8},
            )
        )
        response = await provider.chat(
            self._request(
                [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": "Who are you?"},
                ]
            )
        )
        assert response.content == "I am a helpful assistant."
        body = self._posted_json(mock_client)
        assert "systemInstruction" in body
        assert body["systemInstruction"]["parts"][0]["text"] == "You are a helpful assistant."
        # The system prompt must not also be sent as a conversation turn.
        for msg in body["contents"]:
            assert msg["role"] != "system"

    async def test_chat_with_tools(self):
        """Tools are sent as ``functionDeclarations`` plus an AUTO ``toolConfig``."""
        provider, mock_client = self._make_provider(
            self._model_reply(
                [{"functionCall": {"name": "get_weather", "args": {"city": "Tokyo"}}}],
                {"promptTokenCount": 30, "candidatesTokenCount": 20},
            )
        )
        response = await provider.chat(
            self._request(
                [{"role": "user", "content": "Weather in Tokyo?"}],
                tools=[
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "description": "Get weather",
                            "parameters": {
                                "type": "object",
                                "properties": {"city": {"type": "string"}},
                            },
                        },
                    }
                ],
            )
        )
        assert response.has_tool_calls
        assert response.tool_calls[0].name == "get_weather"
        assert response.tool_calls[0].arguments == {"city": "Tokyo"}
        body = self._posted_json(mock_client)
        assert "tools" in body
        assert "functionDeclarations" in body["tools"][0]
        assert body["tools"][0]["functionDeclarations"][0]["name"] == "get_weather"
        assert "toolConfig" in body
        assert body["toolConfig"]["functionCallingConfig"]["mode"] == "AUTO"

    async def test_chat_api_key_in_url(self):
        """The API key is passed as the ``key`` URL query parameter."""
        provider, mock_client = self._make_provider(
            self._model_reply(
                [{"text": "OK"}],
                {"promptTokenCount": 5, "candidatesTokenCount": 2},
            ),
            api_key="my-secret-key",
        )
        await provider.chat(self._request([{"role": "user", "content": "Hi"}]))
        assert "key=my-secret-key" in self._posted_url(mock_client)

    async def test_chat_with_custom_base_url(self):
        """A custom ``base_url`` is used for the request URL."""
        provider, mock_client = self._make_provider(
            self._model_reply(
                [{"text": "Proxy response"}],
                {"promptTokenCount": 5, "candidatesTokenCount": 3},
            ),
            base_url="https://custom-proxy.example.com",
        )
        response = await provider.chat(self._request([{"role": "user", "content": "Hi"}]))
        assert response.content == "Proxy response"
        assert "custom-proxy.example.com" in self._posted_url(mock_client)
class TestGeminiStreaming:
    """Tests for ``GeminiProvider.chat_stream()``.

    ``_client.stream`` is replaced with a mock that returns an async context
    manager. Its response yields canned SSE lines from ``aiter_lines()``.
    """

    def _make_stream_response(
        self,
        sse_lines: list[str],
        status_code: int = 200,
        body: bytes = b"",
    ):
        """Create a mock async context manager wrapping a streaming response.

        Args:
            sse_lines: lines yielded, in order, by ``response.aiter_lines()``.
            status_code: HTTP status reported by the response (default 200).
            body: bytes returned by ``response.aread()``, e.g. an error payload.
        """
        response = MagicMock()
        response.status_code = status_code

        async def aiter_lines():
            for line in sse_lines:
                yield line

        response.aiter_lines = aiter_lines
        response.aread = AsyncMock(return_value=body)
        context = MagicMock()
        context.__aenter__ = AsyncMock(return_value=response)
        context.__aexit__ = AsyncMock(return_value=False)
        return context

    @staticmethod
    def _provider_with(context):
        """Return a provider whose ``_client.stream(...)`` returns *context*."""
        mock_client = MagicMock()
        mock_client.stream = MagicMock(return_value=context)
        provider = GeminiProvider(api_key="test-key")
        provider._client = mock_client
        return provider

    @staticmethod
    def _request(content, **extra):
        """Build a single-user-message ``LLMRequest`` for ``gemini-2.0-flash``."""
        return LLMRequest(
            messages=[{"role": "user", "content": content}],
            model="gemini-2.0-flash",
            **extra,
        )

    @staticmethod
    async def _collect(provider, request):
        """Drain ``provider.chat_stream(request)`` into a list of chunks."""
        return [chunk async for chunk in provider.chat_stream(request)]

    async def test_stream_text_response(self):
        """Each SSE text event is emitted as its own content chunk."""
        sse_lines = [
            'data: {"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":3,"totalTokenCount":8}}',
            '',
            'data: {"candidates":[{"content":{"parts":[{"text":" world"}],"role":"model"},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":5,"totalTokenCount":10}}',
            '',
        ]
        provider = self._provider_with(self._make_stream_response(sse_lines))
        chunks = await self._collect(provider, self._request("Hi"))
        text_chunks = [c for c in chunks if c.content]
        assert len(text_chunks) == 2
        assert text_chunks[0].content == "Hello"
        assert text_chunks[1].content == " world"

    async def test_stream_function_call_response(self):
        """A streamed ``functionCall`` is reported on the single final chunk."""
        sse_lines = [
            'data: {"candidates":[{"content":{"parts":[{"functionCall":{"name":"get_weather","args":{"city":"Paris"}}}],"role":"model"},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":20,"candidatesTokenCount":15}}',
            '',
        ]
        provider = self._provider_with(self._make_stream_response(sse_lines))
        request = self._request(
            "Weather in Paris?",
            tools=[
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "description": "Get weather",
                        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
                    },
                }
            ],
        )
        chunks = await self._collect(provider, request)
        final_chunks = [c for c in chunks if c.is_final]
        assert len(final_chunks) == 1
        assert len(final_chunks[0].tool_calls) == 1
        assert final_chunks[0].tool_calls[0].name == "get_weather"
        assert final_chunks[0].tool_calls[0].arguments == {"city": "Paris"}

    async def test_stream_with_usage_metadata(self):
        """Usage from a trailing metadata-only event lands on the final chunk."""
        sse_lines = [
            'data: {"candidates":[{"content":{"parts":[{"text":"Hi"}],"role":"model"},"finishReason":"STOP"}]}',
            '',
            'data: {"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15}}',
            '',
        ]
        provider = self._provider_with(self._make_stream_response(sse_lines))
        chunks = await self._collect(provider, self._request("Hi"))
        final_chunks = [c for c in chunks if c.is_final]
        assert len(final_chunks) == 1
        assert final_chunks[0].usage is not None
        assert final_chunks[0].usage.prompt_tokens == 10
        assert final_chunks[0].usage.completion_tokens == 5

    async def test_stream_non_200_status(self):
        """A non-200 streaming response raises LLMProviderError naming the status."""
        provider = self._provider_with(
            self._make_stream_response(
                [],
                status_code=429,
                body=b'{"error":{"code":429,"message":"Rate limit exceeded"}}',
            )
        )
        with pytest.raises(LLMProviderError) as exc_info:
            await self._collect(provider, self._request("Hi"))
        assert "429" in str(exc_info.value)
class TestGeminiErrors:
    """Error handling in ``GeminiProvider.chat()``.

    Every HTTP error status and network failure must surface as
    ``LLMProviderError``, and the HTTP-status cases must name the status code.
    """

    def _make_mock_response(self, status_code: int, json_data: dict):
        """Create a mock ``httpx.Response`` with the given status and JSON body."""
        response = MagicMock(spec=httpx.Response)
        response.status_code = status_code
        response.json = MagicMock(return_value=json_data)
        response.content = json.dumps(json_data).encode()
        return response

    def _failing_provider(self, status_code: int, message: str, api_key: str = "test-key"):
        """Return a provider whose client answers ``post`` with a Gemini error body.

        The body has the shape ``{"error": {"code": status_code, "message": message}}``.
        """
        mock_client = MagicMock(spec=httpx.AsyncClient)
        mock_client.post = AsyncMock(
            return_value=self._make_mock_response(
                status_code,
                {"error": {"code": status_code, "message": message}},
            )
        )
        provider = GeminiProvider(api_key=api_key)
        provider._client = mock_client
        return provider

    @staticmethod
    def _request():
        """A minimal single-message request for ``gemini-2.0-flash``."""
        return LLMRequest(
            messages=[{"role": "user", "content": "Hi"}],
            model="gemini-2.0-flash",
        )

    async def test_400_bad_request(self):
        """400 raises LLMProviderError naming the provider and the status."""
        provider = self._failing_provider(400, "Invalid request")
        with pytest.raises(LLMProviderError) as exc_info:
            await provider.chat(self._request())
        assert "gemini" in str(exc_info.value)
        assert "400" in str(exc_info.value)

    async def test_403_api_key_invalid(self):
        """403 (invalid API key) raises LLMProviderError naming the status."""
        provider = self._failing_provider(403, "API key not valid", api_key="bad-key")
        with pytest.raises(LLMProviderError) as exc_info:
            await provider.chat(self._request())
        assert "403" in str(exc_info.value)

    async def test_429_rate_limit(self):
        """429 (rate limit) raises LLMProviderError naming the status."""
        provider = self._failing_provider(429, "Rate limit exceeded")
        with pytest.raises(LLMProviderError) as exc_info:
            await provider.chat(self._request())
        assert "429" in str(exc_info.value)

    async def test_500_server_error(self):
        """500 raises LLMProviderError naming the status (as the other cases do)."""
        provider = self._failing_provider(500, "Internal server error")
        with pytest.raises(LLMProviderError) as exc_info:
            await provider.chat(self._request())
        assert "500" in str(exc_info.value)

    async def test_503_service_unavailable(self):
        """503 raises LLMProviderError naming the status."""
        provider = self._failing_provider(503, "Service unavailable")
        with pytest.raises(LLMProviderError) as exc_info:
            await provider.chat(self._request())
        assert "503" in str(exc_info.value)

    async def test_network_error(self):
        """A transport-level httpx error is wrapped in LLMProviderError."""
        mock_client = MagicMock(spec=httpx.AsyncClient)
        mock_client.post = AsyncMock(side_effect=httpx.ConnectError("Connection refused"))
        provider = GeminiProvider(api_key="test-key")
        provider._client = mock_client
        with pytest.raises(LLMProviderError):
            await provider.chat(self._request())

    async def test_error_does_not_expose_api_key(self):
        """The error message must not contain the API key."""
        provider = self._failing_provider(
            403, "API key not valid", api_key="my-super-secret-key-12345"
        )
        with pytest.raises(LLMProviderError) as exc_info:
            await provider.chat(self._request())
        assert "my-super-secret-key-12345" not in str(exc_info.value)
class TestGeminiGetModelInfo:
    """Tests for ``GeminiProvider.get_model_info()``."""

    def test_returns_provider_and_model_info(self):
        """Explicit constructor arguments are reflected in the info dict."""
        info = GeminiProvider(
            api_key="test-key",
            model="gemini-2.0-flash",
            max_output_tokens=8192,
        ).get_model_info()
        expected = {
            "provider": "gemini",
            "model": "gemini-2.0-flash",
            "max_output_tokens": 8192,
        }
        for key, value in expected.items():
            assert info[key] == value

    def test_default_model_info(self):
        """Without overrides, the default model and output-token limit are reported."""
        info = GeminiProvider(api_key="test-key").get_model_info()
        expected = {
            "provider": "gemini",
            "model": "gemini-2.0-flash",
            "max_output_tokens": 4096,
        }
        for key, value in expected.items():
            assert info[key] == value
class TestGeminiLazyClient:
    """Lazy creation, reuse and reset of the provider's HTTP client."""

    def test_client_not_created_on_init(self):
        """Constructing the provider must not create an HTTP client."""
        provider = GeminiProvider(api_key="test-key")
        assert provider._client is None

    async def test_client_created_on_first_use(self):
        """The first ``_get_client()`` call creates and stores the client."""
        provider = GeminiProvider(api_key="test-key")
        try:
            client = provider._get_client()
            assert client is not None
            assert provider._client is not None
        finally:
            # Close the client created above so the test does not leak it.
            await provider.close()

    async def test_client_reused(self):
        """Repeated ``_get_client()`` calls return the same client instance."""
        provider = GeminiProvider(api_key="test-key")
        try:
            client1 = provider._get_client()
            client2 = provider._get_client()
            assert client1 is client2
        finally:
            # Close the shared client so the test does not leak it.
            await provider.close()

    async def test_close_resets_client(self):
        """``close()`` releases the client and resets ``_client`` to ``None``."""
        provider = GeminiProvider(api_key="test-key")
        _ = provider._get_client()
        assert provider._client is not None
        await provider.close()
        assert provider._client is None