fischer-agentkit/src/agentkit/llm/providers/litellm_provider.py

389 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""U15 — LiteLLM 统一 Provider 适配层。
用 LiteLLM 的 ``acompletion()`` 统一接口替换 6 个直接 API provider 适配器
(OpenAI/Anthropic/Gemini/Doubao/Wenxin/Yuanbao)。LiteLLM 内部处理各家 API
的差异(消息格式、tool_calls 格式、streaming SSE 协议),本模块只负责:
1. 把 ``LLMRequest`` 翻译成 LiteLLM ``acompletion`` 的 kwargs;
2. 把 LiteLLM 响应(OpenAI ChatCompletion 格式)翻译回 ``LLMResponse`` /
   ``StreamChunk``;
3. 把 LiteLLM 异常包装成 ``LLMProviderError``,保留 fallback 链可用性。
设计取舍(ponytail):
- 自建的 fallback 链 / usage tracking / 部门配额 全部保留在 ``LLMGateway``,
  本 provider 不重复实现。LiteLLM 自带的 fallback / retry 在此处禁用
  (``num_retries=0``),避免与 gateway 的 fallback 重复叠加导致超时放大。
- ``wenxin`` 在 LiteLLM 中无原生支持,回退到 ``openai/`` 前缀 + 自定义
  ``api_base``(千帆 OpenAI 兼容端点)。升级路径:LiteLLM 上游支持后切换到
  ``wenxin/`` 前缀。
- 旧的 6 个 provider 类(``OpenAICompatibleProvider`` 等)作为死代码保留一个
  release,便于回滚;新代码通过 ``create_litellm_provider`` 工厂构造。
"""
from __future__ import annotations
import inspect
import json
import logging
import time
from collections.abc import AsyncGenerator, Iterable
from agentkit.core.exceptions import LLMProviderError
from agentkit.llm.protocol import (
LLMProvider,
LLMRequest,
LLMResponse,
StreamChunk,
TokenUsage,
ToolCall,
)
logger = logging.getLogger(__name__)
# provider_type → LiteLLM model 前缀映射。
# LiteLLM 通过 model 字符串前缀路由到对应 SDK``openai/gpt-4o``、
# ``anthropic/claude-...``、``gemini/gemini-...``、``volcengine/...``、
# ``hunyuan/...``。未知类型回退到 ``openai/``(最大兼容性)。
_PROVIDER_TYPE_TO_PREFIX: dict[str, str] = {
"openai": "openai/",
"anthropic": "anthropic/",
"gemini": "gemini/",
# 豆包由火山引擎提供LiteLLM 前缀为 volcengine/
"doubao": "volcengine/",
# ponytail: LiteLLM 暂无 wenxin 原生支持,回退到 openai/ + 自定义 api_base
# (千帆 OpenAI 兼容端点。ceiling: wenxin 专属参数(如 AK/SK 鉴权)不支持;
# 升级路径:上游支持后改 "wenxin/"。
"wenxin": "openai/",
# 腾讯混元
"yuanbao": "hunyuan/",
}
def _model_prefix_for(provider_type: str) -> str:
    """Return the LiteLLM model prefix for ``provider_type``.

    Any type not listed in ``_PROVIDER_TYPE_TO_PREFIX`` resolves to
    ``"openai/"``.
    """
    try:
        return _PROVIDER_TYPE_TO_PREFIX[provider_type]
    except KeyError:
        return "openai/"
class LitellmProvider(LLMProvider):
    """Unified provider built on LiteLLM's ``acompletion``.

    One instance corresponds to one provider configuration (api_key +
    base_url + provider_type). ``chat`` / ``chat_stream`` forward an
    ``LLMRequest`` to LiteLLM and translate the response back into
    ``LLMResponse`` / ``StreamChunk``.

    Note: this class holds no httpx client of its own; connection handling is
    left to LiteLLM.
    """

    # ``acompletion`` kwargs that ``_build_kwargs`` derives on every call.
    # Constructor defaults must not override them: a stray ``model`` /
    # ``messages`` / ``stream`` would break the chat vs. chat_stream contract,
    # and ``num_retries`` must stay 0 so LiteLLM retries do not stack on top of
    # the gateway's fallback chain (timeout amplification).
    _RESERVED_KWARGS = frozenset({"model", "messages", "stream", "num_retries"})

    def __init__(
        self,
        model_prefix: str,
        api_key: str,
        base_url: str | None = None,
        provider_type: str = "openai",
        **default_kwargs: object,
    ) -> None:
        """Store the provider configuration.

        Args:
            model_prefix: LiteLLM routing prefix prepended to ``request.model``.
            api_key: API key passed to LiteLLM on every call.
            base_url: Optional custom endpoint, sent as ``api_base``; an empty
                string is treated as unset.
            provider_type: Config-level provider name; used as the provider
                tag of every ``LLMProviderError`` raised by this instance.
            **default_kwargs: Extra ``acompletion`` kwargs merged into every
                call. Keys in ``_RESERVED_KWARGS`` are ignored (with a warning).
        """
        self._model_prefix = model_prefix
        self._api_key = api_key
        self._base_url = base_url or None  # empty string counts as "not set"
        self._provider_type = provider_type
        self._default_kwargs: dict[str, object] = dict(default_kwargs)
        ignored = self._RESERVED_KWARGS.intersection(self._default_kwargs)
        if ignored:
            logger.warning(
                "LitellmProvider(%s): ignoring reserved default kwargs %s",
                provider_type,
                sorted(ignored),
            )

    async def chat(self, request: LLMRequest) -> LLMResponse:
        """Non-streaming chat: call ``litellm.acompletion`` and translate the result.

        Raises:
            LLMProviderError: wrapping any exception raised by the LiteLLM call,
                so the gateway's fallback chain can recognise it.
        """
        import litellm

        kwargs = self._build_kwargs(request, stream=False)
        start = time.monotonic()
        try:
            response = await litellm.acompletion(**kwargs)
        except Exception as e:
            # LiteLLM raises many exception types (openai.APIError,
            # anthropic.APIError, its own classes); normalise them all so the
            # gateway fallback chain sees a single type.
            raise LLMProviderError(self._provider_type, str(e)) from e
        latency_ms = (time.monotonic() - start) * 1000
        return self._parse_response(response, request.model, latency_ms)

    async def chat_stream(self, request: LLMRequest) -> AsyncGenerator[StreamChunk, None]:
        """Streaming chat: call ``litellm.acompletion(stream=True)`` and translate chunks.

        Yields one non-final ``StreamChunk`` per upstream chunk, then a final
        chunk (``is_final=True``) carrying the aggregated tool calls, the last
        usage seen and the most recent model name (falling back to
        ``request.model``).

        Async-generator safety: project rules forbid ``return`` before the
        first ``yield``; this function has no early-exit branch, so no
        ``return; yield`` guard is needed.

        Raises:
            LLMProviderError: wrapping any exception raised while creating or
                consuming the stream.
        """
        # NOTE(review): OpenAI-style endpoints only report usage in a stream
        # when ``stream_options={"include_usage": True}`` is requested; it is not
        # set here, so ``final_usage`` may stay None for them — confirm.
        import litellm

        kwargs = self._build_kwargs(request, stream=True)
        accumulated_tool_calls: dict[int, dict[str, object]] = {}
        final_usage: TokenUsage | None = None
        final_model: str = request.model
        try:
            # The return type of litellm.acompletion(stream=True) depends on
            # the version / call style:
            # - real litellm: a coroutine; awaiting it yields the async iterator;
            # - test mocks (``async def`` + ``yield``): an async generator
            #   returned directly, which cannot be awaited.
            # ``inspect.isawaitable`` covers both paths.
            raw = litellm.acompletion(**kwargs)
            stream = await raw if inspect.isawaitable(raw) else raw
            async for chunk in stream:
                parsed = self._parse_stream_chunk(
                    chunk,
                    request.model,
                    accumulated_tool_calls,
                )
                # Track the latest usage / model reported by the stream.
                if parsed.usage is not None:
                    final_usage = parsed.usage
                if parsed.model:
                    final_model = parsed.model
                # Yield every chunk, including empty-content ones: some
                # providers attach usage only to the last, content-less chunk.
                yield parsed
            # End of stream: emit the terminating chunk with the aggregated
            # tool_calls and usage.
            tool_calls_list = self._finalize_tool_calls(accumulated_tool_calls)
            yield StreamChunk(
                content="",
                model=final_model,
                tool_calls=tool_calls_list,
                usage=final_usage,
                is_final=True,
            )
        except Exception as e:
            raise LLMProviderError(self._provider_type, str(e)) from e

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _build_kwargs(self, request: LLMRequest, *, stream: bool) -> dict[str, object]:
        """Build the ``litellm.acompletion`` kwargs for ``request``.

        Constructor defaults are merged last, so they may still override
        tuning parameters such as ``temperature`` or ``timeout``; keys in
        ``_RESERVED_KWARGS`` are skipped and always reflect this call.
        """
        kwargs: dict[str, object] = {
            "model": f"{self._model_prefix}{request.model}",
            "messages": request.messages,
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
            "api_key": self._api_key,
            "stream": stream,
            # LiteLLM's own retry is disabled: retries/fallback belong to the
            # gateway's fallback chain / RetryPolicy.
            "num_retries": 0,
        }
        if self._base_url:
            kwargs["api_base"] = self._base_url
        if request.tools:
            kwargs["tools"] = request.tools
            kwargs["tool_choice"] = request.tool_choice
        if request.timeout is not None:
            kwargs["timeout"] = request.timeout
        # U17: forward LiteLLM cache params (cache key or no-cache) to
        # litellm.acompletion. getattr keeps requests without ``_cache`` working.
        cache_params = getattr(request, "_cache", None)
        if cache_params is not None:
            kwargs["cache"] = cache_params
        # Merge constructor defaults (provider-specific extras such as
        # max_connections), except the reserved per-call keys.
        kwargs.update(
            (key, value)
            for key, value in self._default_kwargs.items()
            if key not in self._RESERVED_KWARGS
        )
        return kwargs

    def _parse_response(
        self,
        response: object,
        request_model: str,
        latency_ms: float,
    ) -> LLMResponse:
        """Translate a LiteLLM response (OpenAI ChatCompletion shape) into ``LLMResponse``.

        Reads ``response.choices[0].message.{content, tool_calls}``,
        ``response.usage`` and ``response.model``; missing pieces fall back to
        empty content, no tool calls, a default ``TokenUsage()`` and
        ``request_model`` respectively.
        """
        choices = getattr(response, "choices", None) or []
        content = ""
        tool_calls: list[ToolCall] = []
        if choices:
            message = getattr(choices[0], "message", None)
            if message is not None:
                content = getattr(message, "content", None) or ""
                raw_tool_calls = getattr(message, "tool_calls", None)
                if raw_tool_calls:
                    tool_calls = _parse_tool_calls(raw_tool_calls)
        usage_obj = getattr(response, "usage", None)
        usage = _parse_usage(usage_obj) if usage_obj is not None else TokenUsage()
        model_name = getattr(response, "model", None) or request_model
        # U17: detect a LiteLLM cache hit via ``_hidden_params`` (contains a
        # ``cache_key`` entry, or a truthy ``cache_hit``).
        # NOTE(review): key presence alone counts as a hit; confirm LiteLLM
        # never sets ``cache_key`` on cache misses.
        cache_hit = False
        hidden = getattr(response, "_hidden_params", None)
        if isinstance(hidden, dict) and ("cache_key" in hidden or hidden.get("cache_hit")):
            cache_hit = True
        return LLMResponse(
            content=content,
            model=model_name,
            usage=usage,
            tool_calls=tool_calls,
            latency_ms=latency_ms,
            cache_hit=cache_hit,
        )

    def _parse_stream_chunk(
        self,
        chunk: object,
        request_model: str,
        accumulated_tool_calls: dict[int, dict[str, object]],
    ) -> StreamChunk:
        """Parse one streaming chunk into a non-final ``StreamChunk``.

        Tool-call fragments from ``choices[0].delta.tool_calls`` are merged into
        ``accumulated_tool_calls`` (mutated in place) rather than returned; the
        final chunk built by ``chat_stream`` carries them.
        """
        choices = getattr(chunk, "choices", None) or []
        content = ""
        model_name = getattr(chunk, "model", None) or request_model
        usage: TokenUsage | None = None
        if choices:
            delta = getattr(choices[0], "delta", None)
            if delta is not None:
                content = getattr(delta, "content", None) or ""
                raw_tool_calls = getattr(delta, "tool_calls", None)
                if raw_tool_calls:
                    _accumulate_stream_tool_calls(raw_tool_calls, accumulated_tool_calls)
        # Some providers attach usage to the last chunk only.
        usage_obj = getattr(chunk, "usage", None)
        if usage_obj is not None:
            usage = _parse_usage(usage_obj)
        return StreamChunk(
            content=content,
            model=model_name,
            tool_calls=[],  # intermediate chunks carry no tool_calls; the final chunk aggregates them
            usage=usage,
            is_final=False,
        )

    def _finalize_tool_calls(
        self,
        accumulated: dict[int, dict[str, object]],
    ) -> list[ToolCall]:
        """Convert accumulated streaming tool-call fragments into ``ToolCall``s.

        Entries are emitted in ascending stream-index order. An empty argument
        string becomes ``{}``; a string that is not valid JSON, or that decodes
        to something other than a JSON object, is kept verbatim as
        ``{"raw": ...}`` so ``arguments`` is always a dict.
        """
        tool_calls: list[ToolCall] = []
        for idx in sorted(accumulated):
            tc_data = accumulated[idx]
            args_str = tc_data.get("arguments_str", "")
            try:
                arguments = json.loads(args_str) if args_str else {}
            except json.JSONDecodeError:
                arguments = {"raw": args_str}
            if not isinstance(arguments, dict):
                # e.g. "null" or "[1, 2]": keep the argument shape consistent.
                arguments = {"raw": args_str}
            tool_calls.append(
                ToolCall(
                    id=tc_data.get("id", ""),
                    name=tc_data.get("name", ""),
                    arguments=arguments,
                )
            )
        return tool_calls
# ----------------------------------------------------------------------
# 响应解析辅助(模块级,便于测试 mock)
# ----------------------------------------------------------------------
def _parse_tool_calls(raw_tool_calls: Iterable[object]) -> list[ToolCall]:
"""解析非流式响应的 tool_callsOpenAI 格式 list[ChoiceMessageToolCall])。"""
result: list[ToolCall] = []
for tc in raw_tool_calls:
# LiteLLM 返回的对象有 .id / .function.{name,arguments}
tc_id = getattr(tc, "id", "") or ""
func = getattr(tc, "function", None)
if func is None:
continue
name = getattr(func, "name", "") or ""
args = getattr(func, "arguments", "{}")
if isinstance(args, str):
try:
args_dict = json.loads(args) if args else {}
except json.JSONDecodeError:
args_dict = {"raw": args}
elif isinstance(args, dict):
args_dict = args
else:
args_dict = {"raw": str(args)}
result.append(ToolCall(id=tc_id, name=name, arguments=args_dict))
return result
def _parse_usage(usage_obj: object) -> TokenUsage:
    """Convert a usage payload into ``TokenUsage``.

    Accepts an object exposing ``prompt_tokens`` / ``completion_tokens``
    attributes (e.g. OpenAI ``CompletionUsage``) or a plain dict with those
    keys. Missing, ``None`` or otherwise falsy counts become 0.
    """
    prompt_tokens = getattr(usage_obj, "prompt_tokens", None)
    completion_tokens = getattr(usage_obj, "completion_tokens", None)
    # A dict exposes its counts as keys, not attributes; consult the keys only
    # when the attribute lookup found no prompt count.
    if isinstance(usage_obj, dict) and prompt_tokens is None:
        prompt_tokens = usage_obj.get("prompt_tokens", 0)
        completion_tokens = usage_obj.get("completion_tokens", 0)
    return TokenUsage(
        prompt_tokens=int(prompt_tokens) if prompt_tokens else 0,
        completion_tokens=int(completion_tokens) if completion_tokens else 0,
    )
def _accumulate_stream_tool_calls(
raw_tool_calls: Iterable[object],
accumulated: dict[int, dict[str, object]],
) -> None:
"""累计流式 chunk 里的 tool_calls 片段OpenAI delta.tool_calls 格式)。
每个 delta tool_call 有 index/id/function.{name,arguments}arguments
是分片字符串,需跨 chunk 拼接。
"""
for tc in raw_tool_calls:
idx = getattr(tc, "index", 0) or 0
if idx not in accumulated:
accumulated[idx] = {
"id": "",
"name": "",
"arguments_str": "",
}
tc_id = getattr(tc, "id", None)
if tc_id:
accumulated[idx]["id"] = tc_id
func = getattr(tc, "function", None)
if func is not None:
name = getattr(func, "name", None)
if name:
accumulated[idx]["name"] = name
args_fragment = getattr(func, "arguments", None)
if args_fragment:
accumulated[idx]["arguments_str"] += args_fragment
# ----------------------------------------------------------------------
# 工厂函数
# ----------------------------------------------------------------------
def create_litellm_provider(
    provider_type: str,
    api_key: str,
    base_url: str | None = None,
    **kwargs: object,
) -> LitellmProvider:
    """Build a ``LitellmProvider`` for a config ``provider_type``.

    Args:
        provider_type: Provider type string from config ("openai" |
            "anthropic" | "gemini" | "doubao" | "wenxin" | "yuanbao").
            Unknown types fall back to the "openai/" prefix.
        api_key: Provider API key.
        base_url: Optional custom API base URL.
        **kwargs: Extra default kwargs forwarded to ``LitellmProvider``.

    Returns:
        A ``LitellmProvider`` whose ``model_prefix`` is resolved from
        ``provider_type``.
    """
    return LitellmProvider(
        model_prefix=_model_prefix_for(provider_type),
        api_key=api_key,
        base_url=base_url,
        provider_type=provider_type,
        **kwargs,
    )