"""U15 — LitellmProvider 单元测试。

用 ``unittest.mock.AsyncMock`` mock ``litellm.acompletion``,覆盖:

- 6 个 provider 类型(openai/anthropic/gemini/doubao/wenxin/yuanbao)经 LiteLLM 调用
- 未知 provider_type 回退到 openai/ 前缀
- 自定义 base_url 透传
- tools 透传 + tool_calls 响应解析
- 流式 chunk 迭代 + 终止 chunk
- 流式空响应(上游不产生任何 chunk)处理
- LiteLLM 异常 → LLMProviderError
- latency_ms 非负
- 额外:num_retries=0 / stream 标志 / timeout 透传、直接构造 LitellmProvider、非法 JSON tool arguments 容错
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from types import SimpleNamespace
|
||
from unittest.mock import AsyncMock, patch
|
||
|
||
import pytest
|
||
|
||
from agentkit.core.exceptions import LLMProviderError
|
||
from agentkit.llm.protocol import LLMRequest, LLMResponse
|
||
from agentkit.llm.providers.litellm_provider import (
|
||
LitellmProvider,
|
||
_model_prefix_for,
|
||
create_litellm_provider,
|
||
)
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# 测试辅助:构造 OpenAI 格式的 fake response / chunk
|
||
# ----------------------------------------------------------------------
|
||
|
||
|
||
def _fake_response(
    content: str = "Hello!",
    model: str = "gpt-4o-mini",
    prompt_tokens: int = 10,
    completion_tokens: int = 5,
    tool_calls: list | None = None,
) -> SimpleNamespace:
    """Build a fake non-streaming response shaped like an OpenAI ChatCompletion.

    Populates ``choices[0].message`` (``content`` / ``tool_calls``), ``usage``
    (prompt / completion token counts) and ``model``.
    """
    token_usage = SimpleNamespace(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    only_choice = SimpleNamespace(
        message=SimpleNamespace(content=content, tool_calls=tool_calls),
    )
    return SimpleNamespace(choices=[only_choice], usage=token_usage, model=model)
|
||
|
||
|
||
def _fake_tool_call(
    tc_id: str = "call_1",
    name: str = "get_weather",
    arguments: str = '{"city": "Beijing"}',
) -> SimpleNamespace:
    """Build a fake non-streaming tool call: ``id`` plus ``function.name/arguments``."""
    function_part = SimpleNamespace(name=name, arguments=arguments)
    return SimpleNamespace(id=tc_id, function=function_part)
|
||
|
||
|
||
def _fake_stream_chunk(
    content: str = "",
    model: str = "gpt-4o-mini",
    tool_calls_delta: list | None = None,
    usage: SimpleNamespace | None = None,
) -> SimpleNamespace:
    """Build a fake streaming chunk shaped like an OpenAI ChatCompletionChunk.

    ``choices[0].delta`` carries the text fragment and any tool-call deltas;
    ``usage`` is stored exactly as given (``None`` by default).
    """
    only_choice = SimpleNamespace(
        delta=SimpleNamespace(content=content, tool_calls=tool_calls_delta),
    )
    return SimpleNamespace(choices=[only_choice], model=model, usage=usage)
|
||
|
||
|
||
def _fake_stream_tool_call_delta(
    index: int = 0,
    tc_id: str | None = "call_1",
    name: str | None = "get_weather",
    arguments_fragment: str | None = None,
) -> SimpleNamespace:
    """Build one fragment of a streamed tool call.

    ``tc_id`` and ``name`` may be ``None``, as on continuation fragments that
    only extend ``function.arguments``.
    """
    function_part = SimpleNamespace(name=name, arguments=arguments_fragment)
    return SimpleNamespace(index=index, id=tc_id, function=function_part)
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# 1-7: provider_type → model_prefix 映射 + 基本 chat
|
||
# ----------------------------------------------------------------------
|
||
|
||
|
||
@pytest.mark.parametrize(
    ("provider_type", "expected_prefix"),
    [
        pytest.param("openai", "openai/", id="openai"),
        pytest.param("anthropic", "anthropic/", id="anthropic"),
        pytest.param("gemini", "gemini/", id="gemini"),
        pytest.param("doubao", "volcengine/", id="doubao"),
        # wenxin falls back to the openai/ prefix.
        pytest.param("wenxin", "openai/", id="wenxin"),
        pytest.param("yuanbao", "hunyuan/", id="yuanbao"),
        # Unrecognized provider types also fall back to openai/.
        pytest.param("unknown_type", "openai/", id="unknown"),
    ],
)
def test_model_prefix_mapping(provider_type: str, expected_prefix: str):
    """Check the provider_type -> LiteLLM model-prefix table, including the fallback."""
    actual_prefix = _model_prefix_for(provider_type)
    assert actual_prefix == expected_prefix
|
||
|
||
|
||
@pytest.mark.parametrize(
    "provider_type, expected_prefix, model_name",
    [
        ("openai", "openai/", "gpt-4o-mini"),
        ("anthropic", "anthropic/", "claude-sonnet-4-20250514"),
        ("gemini", "gemini/", "gemini-2.0-flash"),
        ("doubao", "volcengine/", "doubao-pro-32k"),
        ("wenxin", "openai/", "ernie-4.5-turbo-128k"),
        ("yuanbao", "hunyuan/", "hunyuan-pro"),
    ],
    ids=["openai", "anthropic", "gemini", "doubao", "wenxin", "yuanbao"],
)
async def test_chat_via_litellm_for_each_provider_type(
    provider_type: str,
    expected_prefix: str,
    model_name: str,
):
    """Each provider type routes through LiteLLM with the correct model prefix.

    Verifies that:
    - the factory picks the expected prefix;
    - that prefix is actually prepended to the ``model`` argument sent to
      ``litellm.acompletion``;
    - the fake response is translated into an ``LLMResponse``.
    """
    provider = create_litellm_provider(provider_type, api_key="sk-test")
    assert provider._model_prefix == expected_prefix

    fake_resp = _fake_response(content="ok", model=model_name)
    mock_acompletion = AsyncMock(return_value=fake_resp)
    with patch("litellm.acompletion", new=mock_acompletion):
        response = await provider.chat(
            LLMRequest(
                messages=[{"role": "user", "content": "Hi"}],
                model=model_name,
            )
        )

    # Check the argument that really reaches LiteLLM. Checking only the private
    # attribute would miss a provider that stores the prefix but never applies it.
    mock_acompletion.assert_awaited_once()
    _, kwargs = mock_acompletion.call_args
    assert kwargs["model"] == f"{expected_prefix}{model_name}"

    assert isinstance(response, LLMResponse)
    assert response.content == "ok"
    assert response.model == model_name
    assert response.usage.prompt_tokens == 10
    assert response.usage.completion_tokens == 5
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# 8: 自定义 base_url 透传
|
||
# ----------------------------------------------------------------------
|
||
|
||
|
||
async def test_custom_base_url_passed_through():
    """A custom ``base_url`` is forwarded to ``litellm.acompletion`` as ``api_base``.

    Also checks that the API key and the prefixed model name are forwarded.
    """
    provider = create_litellm_provider(
        "openai", api_key="sk-test", base_url="https://custom.api/v1"
    )
    request = LLMRequest(
        messages=[{"role": "user", "content": "Hi"}], model="gpt-4o-mini"
    )
    mock_acompletion = AsyncMock(return_value=_fake_response())

    with patch("litellm.acompletion", new=mock_acompletion):
        await provider.chat(request)

    sent = mock_acompletion.call_args.kwargs
    assert sent["api_base"] == "https://custom.api/v1"
    assert sent["api_key"] == "sk-test"
    assert sent["model"] == "openai/gpt-4o-mini"
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# 9: tools 透传
|
||
# ----------------------------------------------------------------------
|
||
|
||
|
||
async def test_tools_passed_through():
    """``request.tools`` and ``request.tool_choice`` reach ``litellm.acompletion`` unchanged."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get weather",
            "parameters": {"type": "object", "properties": {}},
        },
    }
    tools = [weather_tool]
    mock_acompletion = AsyncMock(return_value=_fake_response())

    with patch("litellm.acompletion", new=mock_acompletion):
        await provider.chat(
            LLMRequest(
                messages=[{"role": "user", "content": "weather?"}],
                model="gpt-4o",
                tools=tools,
                tool_choice="auto",
            )
        )

    sent = mock_acompletion.call_args.kwargs
    assert sent["tools"] == tools
    assert sent["tool_choice"] == "auto"
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# 10: tool_calls 响应解析
|
||
# ----------------------------------------------------------------------
|
||
|
||
|
||
async def test_tool_calls_in_response_parsed():
    """Tool calls in a non-streaming response become ``LLMResponse.tool_calls``.

    The JSON-string ``arguments`` is expected to be decoded into a dict.
    """
    provider = create_litellm_provider("openai", api_key="sk-test")
    raw_call = _fake_tool_call(
        tc_id="call_abc",
        name="get_weather",
        arguments='{"city": "Beijing"}',
    )
    fake_resp = _fake_response(content="", tool_calls=[raw_call])
    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")

    with patch("litellm.acompletion", new=AsyncMock(return_value=fake_resp)):
        response = await provider.chat(request)

    assert response.has_tool_calls
    assert len(response.tool_calls) == 1
    parsed = response.tool_calls[0]
    assert (parsed.id, parsed.name) == ("call_abc", "get_weather")
    assert parsed.arguments == {"city": "Beijing"}
|
||
|
||
|
||
async def test_tool_calls_with_dict_arguments():
    """If ``tool_call.function.arguments`` is already a dict, it is used as-is."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    dict_args_call = SimpleNamespace(
        id="call_1",
        function=SimpleNamespace(name="fn", arguments={"k": "v"}),
    )
    mock_acompletion = AsyncMock(return_value=_fake_response(tool_calls=[dict_args_call]))
    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")

    with patch("litellm.acompletion", new=mock_acompletion):
        response = await provider.chat(request)

    first_call = response.tool_calls[0]
    assert first_call.arguments == {"k": "v"}
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# 11: 流式 yields chunks
|
||
# ----------------------------------------------------------------------
|
||
|
||
|
||
async def test_streaming_yields_chunks():
    """Streaming yields non-final chunks for the upstream stream, then one final chunk.

    The final chunk carries the usage reported by the last upstream chunk.
    """
    provider = create_litellm_provider("openai", api_key="sk-test")

    upstream = [
        _fake_stream_chunk(content="Hello"),
        _fake_stream_chunk(content=" world"),
        _fake_stream_chunk(
            usage=SimpleNamespace(prompt_tokens=8, completion_tokens=2),
        ),
    ]

    async def _fake_stream(**_kwargs):
        for item in upstream:
            yield item

    request = LLMRequest(
        messages=[{"role": "user", "content": "Hi"}], model="gpt-4o-mini"
    )
    with patch("litellm.acompletion", new=_fake_stream):
        results = [chunk async for chunk in provider.chat_stream(request)]

    # Three streamed (is_final=False) chunks followed by one terminating chunk.
    assert len(results) == 4
    assert results[0].content == "Hello"
    assert results[1].content == " world"
    assert not any(r.is_final for r in results[:3])

    final = results[-1]
    assert final.is_final is True
    # The terminating chunk carries the aggregated usage.
    assert final.usage is not None
    assert (final.usage.prompt_tokens, final.usage.completion_tokens) == (8, 2)
|
||
|
||
|
||
async def test_streaming_aggregates_tool_calls():
    """Tool-call fragments spread across stream chunks are merged in the final chunk.

    The first fragment carries the id, the name and the start of the JSON
    arguments. The second fragment (``id``/``name`` = ``None``) carries the rest.
    """
    provider = create_litellm_provider("openai", api_key="sk-test")

    opening = _fake_stream_tool_call_delta(
        index=0,
        tc_id="call_1",
        name="get_weather",
        arguments_fragment='{"city":',
    )
    continuation = _fake_stream_tool_call_delta(
        index=0,
        tc_id=None,
        name=None,
        arguments_fragment=' "Beijing"}',
    )
    upstream = [
        _fake_stream_chunk(tool_calls_delta=[opening]),
        _fake_stream_chunk(tool_calls_delta=[continuation]),
    ]

    async def _fake_stream(**_kwargs):
        for piece in upstream:
            yield piece

    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
    with patch("litellm.acompletion", new=_fake_stream):
        results = [chunk async for chunk in provider.chat_stream(request)]

    final = results[-1]
    assert final.is_final
    assert len(final.tool_calls) == 1
    merged = final.tool_calls[0]
    assert merged.id == "call_1"
    assert merged.name == "get_weather"
    assert merged.arguments == {"city": "Beijing"}
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# 12: 流式空消息处理
|
||
# ----------------------------------------------------------------------
|
||
|
||
|
||
async def test_streaming_empty_messages_handled():
    """An upstream stream with no chunks still ends in an empty ``is_final=True`` chunk.

    The provider must not crash on an empty stream.
    """
    provider = create_litellm_provider("openai", api_key="sk-test")

    async def _empty_stream(**_kwargs):
        # The unreachable ``yield`` makes this an async generator that finishes
        # immediately without producing any chunk.
        return
        yield

    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
    with patch("litellm.acompletion", new=_empty_stream):
        results = [chunk async for chunk in provider.chat_stream(request)]

    # At least the terminating chunk must be emitted.
    assert len(results) >= 1
    last = results[-1]
    assert last.is_final is True
    assert last.content == ""
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# 13: LiteLLM 异常 → LLMProviderError
|
||
# ----------------------------------------------------------------------
|
||
|
||
|
||
async def test_litellm_exception_raises_provider_error():
    """An exception from ``litellm.acompletion`` surfaces as ``LLMProviderError``.

    The original error message must appear in the wrapped error's text.
    """
    provider = create_litellm_provider("openai", api_key="sk-test")
    exploding = AsyncMock(side_effect=RuntimeError("boom"))
    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")

    with patch("litellm.acompletion", new=exploding), pytest.raises(
        LLMProviderError
    ) as exc_info:
        await provider.chat(request)

    assert "boom" in str(exc_info.value)
|
||
|
||
|
||
async def test_litellm_stream_exception_raises_provider_error():
    """An exception raised while iterating the LiteLLM stream becomes ``LLMProviderError``."""
    provider = create_litellm_provider("openai", api_key="sk-test")

    async def _failing_stream(**_kwargs):
        raise RuntimeError("stream boom")
        yield  # never reached; only turns this into an async generator

    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
    with patch("litellm.acompletion", new=_failing_stream):
        with pytest.raises(LLMProviderError) as exc_info:
            async for _chunk in provider.chat_stream(request):
                pass

    assert "stream boom" in str(exc_info.value)
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# 14: latency 非负
|
||
# ----------------------------------------------------------------------
|
||
|
||
|
||
async def test_latency_measured_non_negative():
    """``LLMResponse.latency_ms`` is never negative."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
    mocked_call = AsyncMock(return_value=_fake_response())

    with patch("litellm.acompletion", new=mocked_call):
        response = await provider.chat(request)

    assert response.latency_ms >= 0
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# 额外:num_retries=0 禁用 LiteLLM 自带 retry
|
||
# ----------------------------------------------------------------------
|
||
|
||
|
||
async def test_litellm_num_retries_disabled():
    """Non-streaming calls pass ``num_retries=0`` and ``stream=False``.

    LiteLLM's built-in retry is disabled because the gateway fallback is
    responsible for retries.
    """
    provider = create_litellm_provider("openai", api_key="sk-test")
    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
    mock_acompletion = AsyncMock(return_value=_fake_response())

    with patch("litellm.acompletion", new=mock_acompletion):
        await provider.chat(request)

    sent = mock_acompletion.call_args.kwargs
    assert sent["num_retries"] == 0
    assert sent["stream"] is False
|
||
|
||
|
||
async def test_streaming_passes_stream_true():
    """``chat_stream`` calls ``litellm.acompletion`` with ``stream=True``."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    seen_kwargs: dict = {}

    async def _no_chunks():
        # Async generator that ends without yielding anything.
        return
        yield

    async def _capturing_stream(**kwargs):
        # Record the call arguments, then resolve to an empty async iterator.
        # This is a coroutine that returns a stream, unlike the plain
        # async-generator fakes used elsewhere in this file.
        seen_kwargs.update(kwargs)
        return _no_chunks()

    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
    with patch("litellm.acompletion", new=_capturing_stream):
        async for _chunk in provider.chat_stream(request):
            pass

    assert seen_kwargs["stream"] is True
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# 额外:timeout 透传
|
||
# ----------------------------------------------------------------------
|
||
|
||
|
||
async def test_timeout_passed_through():
    """``request.timeout`` is forwarded unchanged to ``litellm.acompletion``."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    request = LLMRequest(
        messages=[{"role": "user", "content": "Hi"}],
        model="gpt-4o",
        timeout=42.0,
    )
    mock_acompletion = AsyncMock(return_value=_fake_response())

    with patch("litellm.acompletion", new=mock_acompletion):
        await provider.chat(request)

    assert mock_acompletion.call_args.kwargs["timeout"] == 42.0
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# 额外:LitellmProvider 直接构造(不经工厂)
|
||
# ----------------------------------------------------------------------
|
||
|
||
|
||
async def test_direct_litellm_provider_construction():
    """A ``LitellmProvider`` built directly (without the factory) works end to end.

    A single mocked call checks three things: the response is translated, the
    configured ``model_prefix`` is prepended to the requested model name, and
    exactly one upstream call is made.
    """
    provider = LitellmProvider(
        model_prefix="anthropic/",
        api_key="sk-test",
        provider_type="anthropic",
    )
    fake_resp = _fake_response(content="hi", model="claude-3")
    mock_acompletion = AsyncMock(return_value=fake_resp)

    with patch("litellm.acompletion", new=mock_acompletion):
        response = await provider.chat(
            LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="claude-3")
        )

    assert response.content == "hi"

    # Inspect the same call that produced the response: one request is sent, and
    # the model prefix is concatenated onto the requested model name.
    mock_acompletion.assert_awaited_once()
    _, kwargs = mock_acompletion.call_args
    assert kwargs["model"] == "anthropic/claude-3"
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# 额外:JSON tool_calls 解析容错
|
||
# ----------------------------------------------------------------------
|
||
|
||
|
||
async def test_tool_calls_invalid_json_arguments_wrapped():
    """Malformed JSON in ``tool_call.arguments`` is wrapped as ``{"raw": ...}``.

    The provider must not raise on unparsable arguments.
    """
    provider = create_litellm_provider("openai", api_key="sk-test")
    broken_args = "not-valid-json{"
    broken_call = _fake_tool_call(tc_id="call_1", name="fn", arguments=broken_args)
    fake_resp = _fake_response(tool_calls=[broken_call])
    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")

    with patch("litellm.acompletion", new=AsyncMock(return_value=fake_resp)):
        response = await provider.chat(request)

    assert response.tool_calls[0].arguments == {"raw": broken_args}
|