fischer-agentkit/src/agentkit/llm/providers/doubao.py

64 lines
1.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""DoubaoProvider - 字节豆包 Provider
支持豆包 1.6 Pro/Lite 系列模型。
API：火山引擎（Volcano Engine Ark）OpenAI 兼容接口
鉴权：Bearer API Key（火山引擎 IAM）
"""
from __future__ import annotations

import copy
import logging
from typing import Any

from agentkit.llm.providers.openai import OpenAICompatibleProvider
logger = logging.getLogger(__name__)

# Doubao model name table: key = name accepted from callers, value = name
# sent to the Ark API. Every known model currently maps to itself, so the
# table is built from a single tuple of names (order preserved). To alias a
# name to something else, replace the comprehension with explicit pairs.
DOUBAO_MODEL_MAP = {
    model_name: model_name
    for model_name in (
        "doubao-pro-32k",
        "doubao-pro-128k",
        "doubao-lite-32k",
        "doubao-lite-128k",
        "doubao-vision",
    )
}

# Default base URL of the Volcano Engine Ark OpenAI-compatible API (cn-beijing).
DOUBAO_DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"
class DoubaoProvider(OpenAICompatibleProvider):
    """ByteDance Doubao provider.

    Calls Doubao models through Volcano Engine Ark's OpenAI-compatible API.
    Request handling is inherited from ``OpenAICompatibleProvider``; this
    subclass only supplies Doubao defaults (base URL, default model) and
    resolves model names through ``DOUBAO_MODEL_MAP`` before each chat call.

    Usage::

        provider = DoubaoProvider(
            api_key="your_ark_api_key",
            # Optional: pass an inference endpoint ID as default_model.
            default_model="doubao-pro-32k",
        )

    Note:
        Volcano Engine requires creating an "inference endpoint" in the
        console to obtain its Service ID; a model name can also be used
        directly as the endpoint ID.
    """

    def __init__(
        self,
        api_key: str,
        base_url: str = DOUBAO_DEFAULT_BASE_URL,
        default_model: str = "doubao-pro-32k",
        **kwargs: Any,
    ) -> None:
        """Create a Doubao provider.

        Args:
            api_key: Volcano Engine Ark API key (used as a Bearer key),
                forwarded to the base class.
            base_url: API base URL; defaults to ``DOUBAO_DEFAULT_BASE_URL``.
            default_model: Model name or inference endpoint ID, forwarded to
                the base class as its default model.
            **kwargs: Extra options forwarded unchanged to
                ``OpenAICompatibleProvider``.
        """
        super().__init__(
            api_key=api_key,
            base_url=base_url,
            default_model=default_model,
            **kwargs,
        )

    async def chat(self, request):
        """Send a chat request after resolving the Doubao model name.

        ``request.model`` is looked up in ``DOUBAO_MODEL_MAP``; names not in
        the map are passed through unchanged. The caller's ``request`` object
        is never modified: when the name needs remapping, a shallow copy
        carrying the resolved name is sent instead.

        Returns:
            Whatever ``OpenAICompatibleProvider.chat`` returns.
        """
        # NOTE(review): only chat() applies the mapping; if the base class
        # exposes other request paths (e.g. streaming) that do not route
        # through chat(), they bypass DOUBAO_MODEL_MAP — TODO confirm.
        resolved = DOUBAO_MODEL_MAP.get(request.model, request.model)
        if resolved != request.model:
            # Copy rather than mutate: the caller may reuse the same request
            # (retry, fallback to another provider) and must keep its model.
            request = copy.copy(request)
            request.model = resolved
        return await super().chat(request)