# fischer-agentkit/tests/unit/llm/test_litellm_provider.py

"""U15 — unit tests for LitellmProvider.

``litellm.acompletion`` is mocked (mostly with ``unittest.mock.AsyncMock``).
Covered:
- the 6 provider types (openai/anthropic/gemini/doubao/wenxin/yuanbao) called through LiteLLM
- an unknown provider_type falls back to the ``openai/`` prefix
- a custom base_url is passed through
- tools are passed through, and tool_calls in the response are parsed
- streaming: chunk iteration plus the terminating chunk
- streaming when the upstream stream produces no chunks
- LiteLLM exceptions -> LLMProviderError
- latency_ms is non-negative
"""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
import pytest
from agentkit.core.exceptions import LLMProviderError
from agentkit.llm.protocol import LLMRequest, LLMResponse
from agentkit.llm.providers.litellm_provider import (
LitellmProvider,
_model_prefix_for,
create_litellm_provider,
)
# ----------------------------------------------------------------------
# 测试辅助:构造 OpenAI 格式的 fake response / chunk
# ----------------------------------------------------------------------
def _fake_response(
content: str = "Hello!",
model: str = "gpt-4o-mini",
prompt_tokens: int = 10,
completion_tokens: int = 5,
tool_calls: list | None = None,
) -> SimpleNamespace:
"""构造非流式 fake responseOpenAI ChatCompletion 格式)。"""
message = SimpleNamespace(content=content, tool_calls=tool_calls)
return SimpleNamespace(
choices=[SimpleNamespace(message=message)],
usage=SimpleNamespace(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
),
model=model,
)
def _fake_tool_call(
    tc_id: str = "call_1",
    name: str = "get_weather",
    arguments: str = '{"city": "Beijing"}',
) -> SimpleNamespace:
    """Build a non-streaming tool_call object: ``id`` plus ``function.name`` / ``function.arguments``."""
    fn = SimpleNamespace(name=name, arguments=arguments)
    return SimpleNamespace(id=tc_id, function=fn)
def _fake_stream_chunk(
content: str = "",
model: str = "gpt-4o-mini",
tool_calls_delta: list | None = None,
usage: SimpleNamespace | None = None,
) -> SimpleNamespace:
"""构造流式 chunkOpenAI ChatCompletionChunk 格式)。"""
delta = SimpleNamespace(content=content, tool_calls=tool_calls_delta)
return SimpleNamespace(
choices=[SimpleNamespace(delta=delta)],
model=model,
usage=usage,
)
def _fake_stream_tool_call_delta(
index: int = 0,
tc_id: str | None = "call_1",
name: str | None = "get_weather",
arguments_fragment: str | None = None,
) -> SimpleNamespace:
"""构造流式 tool_call delta 片段。"""
return SimpleNamespace(
index=index,
id=tc_id,
function=SimpleNamespace(name=name, arguments=arguments_fragment),
)
# ----------------------------------------------------------------------
# 1-7: provider_type → model_prefix 映射 + 基本 chat
# ----------------------------------------------------------------------
@pytest.mark.parametrize(
    ("provider_type", "expected_prefix"),
    [
        pytest.param("openai", "openai/", id="openai"),
        pytest.param("anthropic", "anthropic/", id="anthropic"),
        pytest.param("gemini", "gemini/", id="gemini"),
        pytest.param("doubao", "volcengine/", id="doubao"),
        # wenxin maps to the same openai/ prefix as the fallback.
        pytest.param("wenxin", "openai/", id="wenxin"),
        pytest.param("yuanbao", "hunyuan/", id="yuanbao"),
        # Anything unrecognised falls back to openai/.
        pytest.param("unknown_type", "openai/", id="unknown"),
    ],
)
def test_model_prefix_mapping(provider_type: str, expected_prefix: str):
    """Check the provider_type -> LiteLLM model-prefix table, including the unknown-type fallback."""
    resolved = _model_prefix_for(provider_type)
    assert resolved == expected_prefix
@pytest.mark.parametrize(
    "provider_type, expected_prefix, model_name",
    [
        ("openai", "openai/", "gpt-4o-mini"),
        ("anthropic", "anthropic/", "claude-sonnet-4-20250514"),
        ("gemini", "gemini/", "gemini-2.0-flash"),
        ("doubao", "volcengine/", "doubao-pro-32k"),
        ("wenxin", "openai/", "ernie-4.5-turbo-128k"),
        ("yuanbao", "hunyuan/", "hunyuan-pro"),
    ],
    ids=["openai", "anthropic", "gemini", "doubao", "wenxin", "yuanbao"],
)
async def test_chat_via_litellm_for_each_provider_type(
    provider_type: str,
    expected_prefix: str,
    model_name: str,
):
    """Each of the 6 provider types is called through LiteLLM.

    Verifies that the factory selects the expected model prefix, that the
    prefixed model name is what actually reaches ``litellm.acompletion``, and
    that the fake OpenAI-style response is translated into an ``LLMResponse``.
    """
    provider = create_litellm_provider(provider_type, api_key="sk-test")
    assert provider._model_prefix == expected_prefix

    fake_resp = _fake_response(content="ok", model=model_name)
    mock_acompletion = AsyncMock(return_value=fake_resp)
    with patch("litellm.acompletion", new=mock_acompletion):
        response = await provider.chat(
            LLMRequest(
                messages=[{"role": "user", "content": "Hi"}],
                model=model_name,
            )
        )

    # Checking only the private ``_model_prefix`` attribute would not catch a
    # provider that stores the prefix but never applies it; inspect the call.
    _, kwargs = mock_acompletion.call_args
    assert kwargs["model"] == f"{expected_prefix}{model_name}"

    assert isinstance(response, LLMResponse)
    assert response.content == "ok"
    assert response.model == model_name
    # Token counts come from _fake_response's defaults (10 prompt / 5 completion).
    assert response.usage.prompt_tokens == 10
    assert response.usage.completion_tokens == 5
# ----------------------------------------------------------------------
# 8: 自定义 base_url 透传
# ----------------------------------------------------------------------
async def test_custom_base_url_passed_through():
    """A custom base_url must reach litellm.acompletion as ``api_base``, alongside key and prefixed model."""
    custom_url = "https://custom.api/v1"
    provider = create_litellm_provider("openai", api_key="sk-test", base_url=custom_url)
    fake_acompletion = AsyncMock(return_value=_fake_response())

    with patch("litellm.acompletion", new=fake_acompletion):
        await provider.chat(
            LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o-mini")
        )

    sent = fake_acompletion.call_args.kwargs
    assert sent["api_base"] == custom_url
    assert sent["api_key"] == "sk-test"
    assert sent["model"] == "openai/gpt-4o-mini"
# ----------------------------------------------------------------------
# 9: tools 透传
# ----------------------------------------------------------------------
async def test_tools_passed_through():
    """request.tools and request.tool_choice are forwarded unchanged to litellm.acompletion."""
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get weather",
            "parameters": {"type": "object", "properties": {}},
        },
    }
    tools = [weather_tool]
    provider = create_litellm_provider("openai", api_key="sk-test")
    fake_acompletion = AsyncMock(return_value=_fake_response())

    with patch("litellm.acompletion", new=fake_acompletion):
        await provider.chat(
            LLMRequest(
                messages=[{"role": "user", "content": "weather?"}],
                model="gpt-4o",
                tools=tools,
                tool_choice="auto",
            )
        )

    forwarded = fake_acompletion.call_args.kwargs
    assert forwarded["tools"] == tools
    assert forwarded["tool_choice"] == "auto"
# ----------------------------------------------------------------------
# 10: tool_calls 响应解析
# ----------------------------------------------------------------------
async def test_tool_calls_in_response_parsed():
    """tool_calls on a non-streaming response become LLMResponse.tool_calls with JSON-decoded arguments."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    weather_call = _fake_tool_call(
        tc_id="call_abc",
        name="get_weather",
        arguments='{"city": "Beijing"}',
    )
    canned = _fake_response(content="", tool_calls=[weather_call])

    with patch("litellm.acompletion", new=AsyncMock(return_value=canned)):
        response = await provider.chat(
            LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
        )

    assert response.has_tool_calls
    assert len(response.tool_calls) == 1
    parsed = response.tool_calls[0]
    assert (parsed.id, parsed.name) == ("call_abc", "get_weather")
    assert parsed.arguments == {"city": "Beijing"}
async def test_tool_calls_with_dict_arguments():
    """If tool_call.arguments is already a dict, it is used as-is."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    dict_fn = SimpleNamespace(name="fn", arguments={"k": "v"})
    canned = _fake_response(tool_calls=[SimpleNamespace(id="call_1", function=dict_fn)])

    with patch("litellm.acompletion", new=AsyncMock(return_value=canned)):
        response = await provider.chat(
            LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
        )

    first_call = response.tool_calls[0]
    assert first_call.arguments == {"k": "v"}
# ----------------------------------------------------------------------
# 11: 流式 yields chunks
# ----------------------------------------------------------------------
async def test_streaming_yields_chunks():
    """Streaming yields one is_final=False chunk per upstream chunk, then one is_final=True terminator.

    The terminating chunk carries the usage reported by the last upstream chunk.
    """
    provider = create_litellm_provider("openai", api_key="sk-test")
    chunks = [
        _fake_stream_chunk(content="Hello"),
        _fake_stream_chunk(content=" world"),
        _fake_stream_chunk(
            usage=SimpleNamespace(prompt_tokens=8, completion_tokens=2),
        ),
    ]

    async def _chunk_stream():
        for c in chunks:
            yield c

    # With stream=True, litellm.acompletion is awaited and resolves to an async
    # iterator of chunks, so mock it as a coroutine returning an async generator
    # rather than as an async-generator function.
    mock_acompletion = AsyncMock(return_value=_chunk_stream())
    with patch("litellm.acompletion", new=mock_acompletion):
        results = [
            chunk
            async for chunk in provider.chat_stream(
                LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o-mini")
            )
        ]

    # The first 3 results mirror the upstream chunks (is_final=False); the last
    # one is the terminating chunk (is_final=True).
    assert len(results) == 4
    assert results[0].content == "Hello"
    assert results[1].content == " world"
    assert all(not r.is_final for r in results[:3])
    assert results[-1].is_final is True
    # The terminating chunk carries the aggregated usage.
    assert results[-1].usage is not None
    assert results[-1].usage.prompt_tokens == 8
    assert results[-1].usage.completion_tokens == 2
async def test_streaming_aggregates_tool_calls():
    """Streamed tool_call fragments are stitched together on the terminating chunk.

    The first delta carries the id, the function name and the start of the JSON
    arguments. The second delta has the same ``index`` and no id/name, and
    carries the rest. The final chunk must expose one complete tool call with
    parsed arguments.
    """
    provider = create_litellm_provider("openai", api_key="sk-test")
    chunks = [
        _fake_stream_chunk(
            tool_calls_delta=[
                _fake_stream_tool_call_delta(
                    index=0,
                    tc_id="call_1",
                    name="get_weather",
                    arguments_fragment='{"city":',
                )
            ]
        ),
        _fake_stream_chunk(
            tool_calls_delta=[
                _fake_stream_tool_call_delta(
                    index=0,
                    tc_id=None,
                    name=None,
                    arguments_fragment=' "Beijing"}',
                )
            ]
        ),
    ]

    async def _chunk_stream():
        for c in chunks:
            yield c

    # With stream=True, litellm.acompletion is awaited and resolves to an async
    # iterator, so mock it as a coroutine returning an async generator.
    mock_acompletion = AsyncMock(return_value=_chunk_stream())
    with patch("litellm.acompletion", new=mock_acompletion):
        results = [
            chunk
            async for chunk in provider.chat_stream(
                LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
            )
        ]

    final = results[-1]
    assert final.is_final
    assert len(final.tool_calls) == 1
    assert final.tool_calls[0].id == "call_1"
    assert final.tool_calls[0].name == "get_weather"
    assert final.tool_calls[0].arguments == {"city": "Beijing"}
# ----------------------------------------------------------------------
# 12: 流式空消息处理
# ----------------------------------------------------------------------
async def test_streaming_empty_messages_handled():
    """An upstream stream that produces no chunks still ends with an empty is_final=True chunk, without crashing."""
    provider = create_litellm_provider("openai", api_key="sk-test")

    async def _no_chunks():
        return
        yield  # never reached; only makes this function an async generator

    # With stream=True, litellm.acompletion is awaited and resolves to an async
    # iterator, so mock it as a coroutine returning an (empty) async generator.
    mock_acompletion = AsyncMock(return_value=_no_chunks())
    with patch("litellm.acompletion", new=mock_acompletion):
        results = [
            chunk
            async for chunk in provider.chat_stream(
                LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
            )
        ]

    # At least the terminating chunk must be yielded.
    assert len(results) >= 1
    assert results[-1].is_final is True
    assert results[-1].content == ""
# ----------------------------------------------------------------------
# 13: LiteLLM 异常 → LLMProviderError
# ----------------------------------------------------------------------
async def test_litellm_exception_raises_provider_error():
    """A failure raised by litellm.acompletion surfaces as LLMProviderError carrying the original message."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    exploding = AsyncMock(side_effect=RuntimeError("boom"))
    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")

    with patch("litellm.acompletion", new=exploding), pytest.raises(LLMProviderError) as exc_info:
        await provider.chat(request)

    assert "boom" in str(exc_info.value)
async def test_litellm_stream_exception_raises_provider_error():
    """An exception raised while iterating the LiteLLM stream is wrapped in LLMProviderError."""
    provider = create_litellm_provider("openai", api_key="sk-test")

    async def _failing_stream():
        raise RuntimeError("stream boom")
        yield  # unreachable; only makes this function an async generator

    # With stream=True, litellm.acompletion is awaited and resolves to an async
    # iterator; here the iterator fails on its first step.
    mock_acompletion = AsyncMock(return_value=_failing_stream())
    with patch("litellm.acompletion", new=mock_acompletion):
        with pytest.raises(LLMProviderError) as exc_info:
            async for _ in provider.chat_stream(
                LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
            ):
                pass
    assert "stream boom" in str(exc_info.value)
# ----------------------------------------------------------------------
# 14: latency 非负
# ----------------------------------------------------------------------
async def test_latency_measured_non_negative():
    """The measured LLMResponse.latency_ms is never negative."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")

    with patch("litellm.acompletion", new=AsyncMock(return_value=_fake_response())):
        result = await provider.chat(request)

    elapsed = result.latency_ms
    assert elapsed >= 0
# ----------------------------------------------------------------------
# Extra: num_retries=0 disables LiteLLM's built-in retry
# ----------------------------------------------------------------------
async def test_litellm_num_retries_disabled():
    """chat() passes num_retries=0 (LiteLLM's own retry is off; gateway fallback owns retries) and stream=False."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    fake_acompletion = AsyncMock(return_value=_fake_response())
    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")

    with patch("litellm.acompletion", new=fake_acompletion):
        await provider.chat(request)

    sent = fake_acompletion.call_args.kwargs
    assert sent["num_retries"] == 0
    assert sent["stream"] is False
async def test_streaming_passes_stream_true():
    """chat_stream() calls litellm.acompletion with stream=True."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    seen_kwargs: dict = {}

    async def _empty_chunks():
        return
        yield  # never reached; only makes this function an async generator

    async def _recording_acompletion(**kwargs):
        # Record what the provider sent, then hand back an empty chunk stream.
        seen_kwargs.update(kwargs)
        return _empty_chunks()

    with patch("litellm.acompletion", new=_recording_acompletion):
        async for _ in provider.chat_stream(
            LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
        ):
            pass

    assert seen_kwargs["stream"] is True
# ----------------------------------------------------------------------
# Extra: timeout is passed through
# ----------------------------------------------------------------------
async def test_timeout_passed_through():
    """request.timeout is forwarded to litellm.acompletion unchanged."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    fake_acompletion = AsyncMock(return_value=_fake_response())
    request = LLMRequest(
        messages=[{"role": "user", "content": "Hi"}],
        model="gpt-4o",
        timeout=42.0,
    )

    with patch("litellm.acompletion", new=fake_acompletion):
        await provider.chat(request)

    assert fake_acompletion.call_args.kwargs["timeout"] == 42.0
# ----------------------------------------------------------------------
# Extra: constructing LitellmProvider directly (without the factory)
# ----------------------------------------------------------------------
async def test_direct_litellm_provider_construction():
    """Constructing LitellmProvider(model_prefix=...) directly, without the factory, also works.

    A single patched call checks both the translated response content and that
    the explicit prefix is prepended to the request's model name.
    """
    provider = LitellmProvider(
        model_prefix="anthropic/",
        api_key="sk-test",
        provider_type="anthropic",
    )
    fake_resp = _fake_response(content="hi", model="claude-3")
    mock_acompletion = AsyncMock(return_value=fake_resp)
    with patch("litellm.acompletion", new=mock_acompletion):
        response = await provider.chat(
            LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="claude-3")
        )

    assert response.content == "hi"
    # The model prefix is concatenated onto the requested model name.
    _, kwargs = mock_acompletion.call_args
    assert kwargs["model"] == "anthropic/claude-3"
# ----------------------------------------------------------------------
# Extra: tolerant parsing of invalid JSON in tool_calls arguments
# ----------------------------------------------------------------------
async def test_tool_calls_invalid_json_arguments_wrapped():
    """Arguments that are not valid JSON are wrapped as {"raw": ...} instead of raising."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    broken_args = "not-valid-json{"
    canned = _fake_response(
        tool_calls=[_fake_tool_call(tc_id="call_1", name="fn", arguments=broken_args)]
    )

    with patch("litellm.acompletion", new=AsyncMock(return_value=canned)):
        response = await provider.chat(
            LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
        )

    assert response.tool_calls[0].arguments == {"raw": broken_args}