389 lines
15 KiB
Python
389 lines
15 KiB
Python
"""U15 — LiteLLM 统一 Provider 适配层。
|
||
|
||
用 LiteLLM 的 ``acompletion()`` 统一接口替换 6 个直接 API provider 适配器
|
||
(OpenAI/Anthropic/Gemini/Doubao/Wenxin/Yuanbao)。LiteLLM 内部处理各家 API
|
||
的差异(消息格式、tool_calls 格式、streaming SSE 协议),本模块只负责:
|
||
|
||
1. 把 ``LLMRequest`` 翻译成 LiteLLM ``acompletion`` 的 kwargs;
|
||
2. 把 LiteLLM 响应(OpenAI ChatCompletion 格式)翻译回 ``LLMResponse`` /
|
||
``StreamChunk``;
|
||
3. 把 LiteLLM 异常包装成 ``LLMProviderError``,保留 fallback 链可用性。
|
||
|
||
设计取舍(ponytail):
|
||
- 自建的 fallback 链 / usage tracking / 部门配额 全部保留在 ``LLMGateway``,
|
||
本 provider 不重复实现。LiteLLM 自带的 fallback / retry 在此处禁用
|
||
(``num_retries=0``),避免与 gateway 的 fallback 重复叠加导致超时放大。
|
||
- ``wenxin`` 在 LiteLLM 中无原生支持,回退到 ``openai/`` 前缀 + 自定义
|
||
``api_base``(千帆 OpenAI 兼容端点)。升级路径:LiteLLM 上游支持后切换到
|
||
``wenxin/`` 前缀。
|
||
- 旧的 6 个 provider 类(``OpenAICompatibleProvider`` 等)保留为死代码一个
|
||
release,便于回滚;新代码通过 ``create_litellm_provider`` 工厂构造。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import inspect
|
||
import json
|
||
import logging
|
||
import time
|
||
from collections.abc import AsyncGenerator, Iterable
|
||
|
||
from agentkit.core.exceptions import LLMProviderError
|
||
from agentkit.llm.protocol import (
|
||
LLMProvider,
|
||
LLMRequest,
|
||
LLMResponse,
|
||
StreamChunk,
|
||
TokenUsage,
|
||
ToolCall,
|
||
)
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# provider_type → LiteLLM model 前缀映射。
|
||
# LiteLLM 通过 model 字符串前缀路由到对应 SDK:``openai/gpt-4o``、
|
||
# ``anthropic/claude-...``、``gemini/gemini-...``、``volcengine/...``、
|
||
# ``hunyuan/...``。未知类型回退到 ``openai/``(最大兼容性)。
|
||
_PROVIDER_TYPE_TO_PREFIX: dict[str, str] = {
|
||
"openai": "openai/",
|
||
"anthropic": "anthropic/",
|
||
"gemini": "gemini/",
|
||
# 豆包由火山引擎提供,LiteLLM 前缀为 volcengine/
|
||
"doubao": "volcengine/",
|
||
# ponytail: LiteLLM 暂无 wenxin 原生支持,回退到 openai/ + 自定义 api_base
|
||
# (千帆 OpenAI 兼容端点)。ceiling: wenxin 专属参数(如 AK/SK 鉴权)不支持;
|
||
# 升级路径:上游支持后改 "wenxin/"。
|
||
"wenxin": "openai/",
|
||
# 腾讯混元
|
||
"yuanbao": "hunyuan/",
|
||
}
|
||
|
||
|
||
def _model_prefix_for(provider_type: str) -> str:
    """Return the LiteLLM model prefix for ``provider_type``.

    Args:
        provider_type: Provider type string from configuration.

    Returns:
        The prefix registered in ``_PROVIDER_TYPE_TO_PREFIX``, or ``"openai/"``
        when the type is not registered.
    """
    try:
        return _PROVIDER_TYPE_TO_PREFIX[provider_type]
    except KeyError:
        # Unknown provider types fall back to the OpenAI-compatible route.
        return "openai/"
|
||
|
||
|
||
class LitellmProvider(LLMProvider):
    """Unified provider built on LiteLLM's ``acompletion``.

    One instance corresponds to one provider configuration
    (api_key + base_url + provider_type). ``chat`` / ``chat_stream`` forward
    an ``LLMRequest`` to LiteLLM and translate the response back into the
    project's ``LLMResponse`` / ``StreamChunk`` types.

    Note: this class does not hold an httpx client; per the original author's
    note, LiteLLM manages its connection pool internally.
    """

    def __init__(
        self,
        model_prefix: str,
        api_key: str,
        base_url: str | None = None,
        provider_type: str = "openai",
        **default_kwargs: object,
    ) -> None:
        """Store the provider configuration.

        Args:
            model_prefix: Prefix prepended to ``request.model`` to form the
                LiteLLM model string (e.g. ``"openai/"``).
            api_key: API key forwarded to LiteLLM on every call.
            base_url: Optional custom API base; an empty string is treated
                as "not set".
            provider_type: Provider label used when raising
                ``LLMProviderError``.
            **default_kwargs: Extra kwargs merged into every
                ``acompletion`` call.
        """
        self._model_prefix = model_prefix
        self._api_key = api_key
        self._base_url = base_url or None  # empty string counts as unset
        self._provider_type = provider_type
        self._default_kwargs: dict[str, object] = dict(default_kwargs)

    async def chat(self, request: LLMRequest) -> LLMResponse:
        """Non-streaming chat: call ``litellm.acompletion`` and translate.

        Args:
            request: The request to forward.

        Returns:
            The translated response, including the measured call latency in
            milliseconds.

        Raises:
            LLMProviderError: If ``litellm.acompletion`` raises any
                ``Exception``.
        """
        import litellm

        kwargs = self._build_kwargs(request, stream=False)

        start = time.monotonic()
        try:
            response = await litellm.acompletion(**kwargs)
        except Exception as e:
            # LiteLLM raises many exception types (openai.APIError /
            # anthropic.APIError / its own); wrap them uniformly so the
            # gateway fallback chain can recognise the failure.
            raise LLMProviderError(self._provider_type, str(e)) from e
        latency_ms = (time.monotonic() - start) * 1000

        return self._parse_response(response, request.model, latency_ms)

    async def chat_stream(self, request: LLMRequest) -> AsyncGenerator[StreamChunk, None]:
        """Streaming chat: call ``litellm.acompletion(stream=True)``.

        Every raw chunk is translated and yielded (including empty-content
        and usage-only chunks). After the stream ends, one extra chunk with
        ``is_final=True`` is yielded carrying the aggregated tool calls and
        the last usage seen.

        Async-generator safety: the project rule forbids ``return`` before
        the first ``yield``; this function has no early-exit branch, so no
        ``return; yield`` guard is needed.

        Raises:
            LLMProviderError: If any ``Exception`` is raised while opening,
                iterating or translating the stream.
        """
        import litellm

        kwargs = self._build_kwargs(request, stream=True)

        accumulated_tool_calls: dict[int, dict[str, object]] = {}
        final_usage: TokenUsage | None = None
        final_model: str = request.model

        try:
            # What litellm.acompletion(stream=True) returns depends on the
            # version / call style:
            # - real litellm: a coroutine that resolves to an async generator;
            # - test mocks (async def + yield): an async generator directly,
            #   which must not be awaited.
            # isawaitable() lets both paths work.
            raw = litellm.acompletion(**kwargs)
            stream = await raw if inspect.isawaitable(raw) else raw
            async for chunk in stream:
                parsed = self._parse_stream_chunk(
                    chunk,
                    request.model,
                    accumulated_tool_calls,
                )
                # Track the running state used by the final chunk.
                if parsed.usage is not None:
                    final_usage = parsed.usage
                if parsed.model:
                    final_model = parsed.model
                # Yield content chunks (even empty ones) and usage-only
                # chunks: some providers only report usage on the last,
                # content-free chunk.
                yield parsed

            # End of stream: emit the terminating chunk with the aggregated
            # tool calls and usage.
            tool_calls_list = self._finalize_tool_calls(accumulated_tool_calls)
            yield StreamChunk(
                content="",
                model=final_model,
                tool_calls=tool_calls_list,
                usage=final_usage,
                is_final=True,
            )
        except Exception as e:
            raise LLMProviderError(self._provider_type, str(e)) from e

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _build_kwargs(self, request: LLMRequest, *, stream: bool) -> dict[str, object]:
        """Build the ``litellm.acompletion`` kwargs for ``request``.

        Args:
            request: Source of model, messages, sampling and tool settings.
            stream: Value for the ``stream`` kwarg.

        Returns:
            A fresh kwargs dict. Constructor-level default kwargs are merged
            last, so they override request-derived keys of the same name.
        """
        kwargs: dict[str, object] = {
            "model": f"{self._model_prefix}{request.model}",
            "messages": request.messages,
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
            "api_key": self._api_key,
            "stream": stream,
            # Disable LiteLLM's own retry; the gateway's fallback chain /
            # RetryPolicy is responsible for retries.
            "num_retries": 0,
        }
        if self._base_url:
            kwargs["api_base"] = self._base_url
        if request.tools:
            kwargs["tools"] = request.tools
            kwargs["tool_choice"] = request.tool_choice
        if request.timeout is not None:
            kwargs["timeout"] = request.timeout
        # U17: pass LiteLLM cache parameters (cache_key or no-cache) through
        # to litellm.acompletion.
        cache_params = request._cache
        if cache_params is not None:
            kwargs["cache"] = cache_params
        # Merge constructor-level default kwargs (provider-specific
        # parameters such as max_connections).
        kwargs.update(self._default_kwargs)
        return kwargs

    def _parse_response(
        self,
        response: object,
        request_model: str,
        latency_ms: float,
    ) -> LLMResponse:
        """Translate a LiteLLM response into an ``LLMResponse``.

        LiteLLM responses follow the OpenAI ChatCompletion layout:
        ``response.choices[0].message.{content, tool_calls}``,
        ``response.usage.{prompt_tokens, completion_tokens}`` and
        ``response.model``. Missing pieces degrade to empty content, no tool
        calls, a default ``TokenUsage()`` and ``request_model``.

        Args:
            response: Object returned by ``litellm.acompletion``.
            request_model: Fallback model name when the response has none.
            latency_ms: Measured call latency, copied into the result.
        """
        choices = getattr(response, "choices", None) or []
        content = ""
        tool_calls: list[ToolCall] = []
        if choices:
            message = getattr(choices[0], "message", None)
            if message is not None:
                content = getattr(message, "content", None) or ""
                raw_tool_calls = getattr(message, "tool_calls", None)
                if raw_tool_calls:
                    tool_calls = _parse_tool_calls(raw_tool_calls)

        usage_obj = getattr(response, "usage", None)
        usage = _parse_usage(usage_obj) if usage_obj is not None else TokenUsage()

        model_name = getattr(response, "model", None) or request_model

        # U17: detect a LiteLLM cache hit (``_hidden_params`` contains
        # ``cache_key`` or a truthy ``cache_hit``).
        cache_hit = False
        hidden = getattr(response, "_hidden_params", None)
        if isinstance(hidden, dict) and ("cache_key" in hidden or hidden.get("cache_hit")):
            cache_hit = True

        return LLMResponse(
            content=content,
            model=model_name,
            usage=usage,
            tool_calls=tool_calls,
            latency_ms=latency_ms,
            cache_hit=cache_hit,
        )

    def _parse_stream_chunk(
        self,
        chunk: object,
        request_model: str,
        accumulated_tool_calls: dict[int, dict[str, object]],
    ) -> StreamChunk:
        """Parse one non-final streaming chunk.

        Tool-call fragments found in the chunk are accumulated into
        ``accumulated_tool_calls`` (mutated in place); the returned chunk
        itself never carries tool calls.

        Args:
            chunk: Raw chunk object from the LiteLLM stream.
            request_model: Fallback model name when the chunk has none.
            accumulated_tool_calls: Per-index accumulator shared across the
                whole stream.
        """
        choices = getattr(chunk, "choices", None) or []
        content = ""
        model_name = getattr(chunk, "model", None) or request_model
        usage: TokenUsage | None = None

        if choices:
            delta = getattr(choices[0], "delta", None)
            if delta is not None:
                content = getattr(delta, "content", None) or ""
                raw_tool_calls = getattr(delta, "tool_calls", None)
                if raw_tool_calls:
                    _accumulate_stream_tool_calls(raw_tool_calls, accumulated_tool_calls)

        # Some providers attach usage to the last chunk.
        usage_obj = getattr(chunk, "usage", None)
        if usage_obj is not None:
            usage = _parse_usage(usage_obj)

        return StreamChunk(
            content=content,
            model=model_name,
            tool_calls=[],  # intermediate chunks carry none; the final chunk aggregates them
            usage=usage,
            is_final=False,
        )

    def _finalize_tool_calls(
        self,
        accumulated: dict[int, dict[str, object]],
    ) -> list[ToolCall]:
        """Convert the accumulated streaming fragments into ``ToolCall`` objects.

        Entries are emitted in ascending index order. The concatenated
        argument string is decoded as JSON; an empty string becomes ``{}``.
        Undecodable text, or JSON that is valid but not an object (``null``,
        a list, a number, ...), is preserved as ``{"raw": <text>}`` so that
        ``arguments`` is always a dict.

        Args:
            accumulated: Per-index dicts with ``id``, ``name`` and
                ``arguments_str`` keys.
        """
        tool_calls: list[ToolCall] = []
        for idx in sorted(accumulated):
            tc_data = accumulated[idx]
            args_str = tc_data.get("arguments_str", "")
            try:
                arguments = json.loads(args_str) if args_str else {}
            except json.JSONDecodeError:
                arguments = {"raw": args_str}
            if not isinstance(arguments, dict):
                # Valid JSON but not an object: keep the raw text rather than
                # handing a non-dict to ToolCall.
                arguments = {"raw": args_str}
            tool_calls.append(
                ToolCall(
                    id=tc_data.get("id", ""),
                    name=tc_data.get("name", ""),
                    arguments=arguments,
                )
            )
        return tool_calls
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# 响应解析辅助(模块级,便于测试 mock)
|
||
# ----------------------------------------------------------------------
|
||
|
||
|
||
def _parse_tool_calls(raw_tool_calls: Iterable[object]) -> list[ToolCall]:
    """Parse the tool calls of a non-streaming response.

    Each item is expected to expose ``.id`` and
    ``.function.{name, arguments}`` (OpenAI ``ChoiceMessageToolCall`` layout).
    Items without a ``function`` attribute are skipped.

    ``arguments`` is normalised so the result is always a dict:

    - a dict is used as is;
    - ``None``, a missing attribute or an empty string becomes ``{}``;
    - a string is decoded as JSON; undecodable text, or JSON that is not an
      object, is preserved as ``{"raw": <text>}``;
    - any other type is preserved as ``{"raw": str(value)}``.

    Args:
        raw_tool_calls: Tool-call objects from ``message.tool_calls``.

    Returns:
        One ``ToolCall`` per item that has a ``function``.
    """
    result: list[ToolCall] = []
    for tc in raw_tool_calls:
        func = getattr(tc, "function", None)
        if func is None:
            continue
        tc_id = getattr(tc, "id", "") or ""
        name = getattr(func, "name", "") or ""
        args = getattr(func, "arguments", None)

        args_dict: dict[str, object]
        if isinstance(args, dict):
            args_dict = args
        elif args is None or args == "":
            # No arguments supplied; previously None became {"raw": "None"}.
            args_dict = {}
        elif isinstance(args, str):
            try:
                decoded = json.loads(args)
            except json.JSONDecodeError:
                args_dict = {"raw": args}
            else:
                # json.loads may return a list/number/None for valid JSON;
                # only a JSON object is a usable argument mapping.
                args_dict = decoded if isinstance(decoded, dict) else {"raw": args}
        else:
            args_dict = {"raw": str(args)}

        result.append(ToolCall(id=tc_id, name=name, arguments=args_dict))
    return result
|
||
|
||
|
||
def _parse_usage(usage_obj: object) -> TokenUsage:
    """Convert a usage payload into a ``TokenUsage``.

    Token counts are read from the ``prompt_tokens`` / ``completion_tokens``
    attributes. When the prompt attribute is absent and the payload is a
    dict, both counts are read from the dict keys instead. Missing or falsy
    counts become ``0``.

    Args:
        usage_obj: Attribute-style usage object or a plain dict.
    """
    prompt_count = getattr(usage_obj, "prompt_tokens", None)
    completion_count = getattr(usage_obj, "completion_tokens", None)

    # Plain dicts expose the counts as keys rather than attributes.
    if isinstance(usage_obj, dict) and prompt_count is None:
        prompt_count = usage_obj.get("prompt_tokens", 0)
        completion_count = usage_obj.get("completion_tokens", 0)

    return TokenUsage(
        prompt_tokens=int(prompt_count or 0),
        completion_tokens=int(completion_count or 0),
    )
|
||
|
||
|
||
def _accumulate_stream_tool_calls(
|
||
raw_tool_calls: Iterable[object],
|
||
accumulated: dict[int, dict[str, object]],
|
||
) -> None:
|
||
"""累计流式 chunk 里的 tool_calls 片段(OpenAI delta.tool_calls 格式)。
|
||
|
||
每个 delta tool_call 有 index/id/function.{name,arguments},arguments
|
||
是分片字符串,需跨 chunk 拼接。
|
||
"""
|
||
for tc in raw_tool_calls:
|
||
idx = getattr(tc, "index", 0) or 0
|
||
if idx not in accumulated:
|
||
accumulated[idx] = {
|
||
"id": "",
|
||
"name": "",
|
||
"arguments_str": "",
|
||
}
|
||
tc_id = getattr(tc, "id", None)
|
||
if tc_id:
|
||
accumulated[idx]["id"] = tc_id
|
||
func = getattr(tc, "function", None)
|
||
if func is not None:
|
||
name = getattr(func, "name", None)
|
||
if name:
|
||
accumulated[idx]["name"] = name
|
||
args_fragment = getattr(func, "arguments", None)
|
||
if args_fragment:
|
||
accumulated[idx]["arguments_str"] += args_fragment
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# 工厂函数
|
||
# ----------------------------------------------------------------------
|
||
|
||
|
||
def create_litellm_provider(
    provider_type: str,
    api_key: str,
    base_url: str | None = None,
    **kwargs: object,
) -> LitellmProvider:
    """Create a ``LitellmProvider`` for the given ``provider_type``.

    Args:
        provider_type: Provider type string from configuration
            ("openai"|"anthropic"|"gemini"|"doubao"|"wenxin"|"yuanbao").
            Unknown types fall back to the "openai/" prefix.
        api_key: Provider API key.
        base_url: Optional custom API base URL.
        **kwargs: Extra default parameters passed through to
            ``LitellmProvider``.

    Returns:
        A ``LitellmProvider`` configured with the matching model prefix.
    """
    return LitellmProvider(
        model_prefix=_model_prefix_for(provider_type),
        api_key=api_key,
        base_url=base_url,
        provider_type=provider_type,
        **kwargs,
    )
|