64 lines
1.7 KiB
Python
64 lines
1.7 KiB
Python
"""DoubaoProvider - 字节豆包 Provider
|
||
|
||
支持豆包 1.6 Pro/Lite 系列模型。
|
||
API:火山引擎 OpenAI 兼容接口
|
||
鉴权:Bearer API Key(火山引擎 IAM)
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
from typing import Any
|
||
|
||
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# 豆包模型映射
|
||
DOUBAO_MODEL_MAP = {
|
||
"doubao-pro-32k": "doubao-pro-32k",
|
||
"doubao-pro-128k": "doubao-pro-128k",
|
||
"doubao-lite-32k": "doubao-lite-32k",
|
||
"doubao-lite-128k": "doubao-lite-128k",
|
||
"doubao-vision": "doubao-vision",
|
||
}
|
||
|
||
# 火山引擎 API base URL
|
||
DOUBAO_DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"
|
||
|
||
|
||
class DoubaoProvider(OpenAICompatibleProvider):
|
||
"""字节豆包 Provider
|
||
|
||
通过火山引擎 OpenAI 兼容接口调用豆包模型。
|
||
|
||
使用方式:
|
||
provider = DoubaoProvider(
|
||
api_key="your_ark_api_key",
|
||
# 可选:指定推理接入点 ID 作为 default_model
|
||
default_model="doubao-pro-32k",
|
||
)
|
||
|
||
注意:火山引擎需要在控制台创建"推理接入点"获取 Service ID,
|
||
也可以直接使用模型名称作为 endpoint_id。
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
api_key: str,
|
||
base_url: str = DOUBAO_DEFAULT_BASE_URL,
|
||
default_model: str = "doubao-pro-32k",
|
||
**kwargs: Any,
|
||
):
|
||
super().__init__(
|
||
api_key=api_key,
|
||
base_url=base_url,
|
||
default_model=default_model,
|
||
**kwargs,
|
||
)
|
||
|
||
async def chat(self, request):
|
||
"""发送 chat 请求,处理豆包模型映射"""
|
||
request.model = DOUBAO_MODEL_MAP.get(request.model, request.model)
|
||
return await super().chat(request)
|