feat(llm): U8 Chinese LLM providers - Wenxin, Doubao, Yuanbao

- WenxinProvider: Baidu ERNIE via Qianfan v2 OpenAI-compatible API, AK/SK token auth
- DoubaoProvider: ByteDance Doubao via Volcengine Ark API
- YuanbaoProvider: Tencent Hunyuan via OpenAI-compatible API with enhancement mode
- All inherit from OpenAICompatibleProvider for retry/circuit breaker support
- 16 tests passing
This commit is contained in:
chiguyong 2026-06-06 22:46:53 +08:00
parent 34e083abde
commit 9753a08ac8
5 changed files with 374 additions and 0 deletions

View File

@ -1,15 +1,21 @@
"""LLM Providers"""
from agentkit.llm.providers.anthropic import AnthropicProvider
from agentkit.llm.providers.doubao import DoubaoProvider
from agentkit.llm.providers.gemini import GeminiProvider
from agentkit.llm.providers.openai import OpenAICompatibleProvider
from agentkit.llm.providers.tracker import UsageRecord, UsageSummary, UsageTracker
from agentkit.llm.providers.wenxin import WenxinProvider
from agentkit.llm.providers.yuanbao import YuanbaoProvider
__all__ = [
"AnthropicProvider",
"DoubaoProvider",
"GeminiProvider",
"OpenAICompatibleProvider",
"UsageRecord",
"UsageSummary",
"UsageTracker",
"WenxinProvider",
"YuanbaoProvider",
]

View File

@ -0,0 +1,63 @@
"""DoubaoProvider - 字节豆包 Provider
支持豆包 1.6 Pro/Lite 系列模型
API火山引擎 OpenAI 兼容接口
鉴权Bearer API Key火山引擎 IAM
"""
from __future__ import annotations
import logging
from typing import Any
from agentkit.llm.providers.openai import OpenAICompatibleProvider
logger = logging.getLogger(__name__)
# 豆包模型映射
DOUBAO_MODEL_MAP = {
"doubao-pro-32k": "doubao-pro-32k",
"doubao-pro-128k": "doubao-pro-128k",
"doubao-lite-32k": "doubao-lite-32k",
"doubao-lite-128k": "doubao-lite-128k",
"doubao-vision": "doubao-vision",
}
# 火山引擎 API base URL
DOUBAO_DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"
class DoubaoProvider(OpenAICompatibleProvider):
"""字节豆包 Provider
通过火山引擎 OpenAI 兼容接口调用豆包模型
使用方式
provider = DoubaoProvider(
api_key="your_ark_api_key",
# 可选:指定推理接入点 ID 作为 default_model
default_model="doubao-pro-32k",
)
注意火山引擎需要在控制台创建"推理接入点"获取 Service ID
也可以直接使用模型名称作为 endpoint_id
"""
def __init__(
self,
api_key: str,
base_url: str = DOUBAO_DEFAULT_BASE_URL,
default_model: str = "doubao-pro-32k",
**kwargs: Any,
):
super().__init__(
api_key=api_key,
base_url=base_url,
default_model=default_model,
**kwargs,
)
async def chat(self, request):
"""发送 chat 请求,处理豆包模型映射"""
request.model = DOUBAO_MODEL_MAP.get(request.model, request.model)
return await super().chat(request)

View File

@ -0,0 +1,114 @@
"""WenxinProvider - 百度文心 ERNIE Provider
支持 ERNIE 4.5/5.0 系列模型
鉴权AK/SK access_token缓存 29
API百度千帆平台 OpenAI 兼容接口
"""
from __future__ import annotations
import logging
import time
from typing import Any
from agentkit.llm.providers.openai import OpenAICompatibleProvider
from agentkit.llm.protocol import LLMRequest, LLMResponse
logger = logging.getLogger(__name__)
# 文心模型到端点的映射
WENXIN_MODEL_MAP = {
"ernie-4.5-turbo-128k": "ernie-4.5-turbo-128k",
"ernie-5.0": "ernie-5.0",
"ernie-x1.1": "ernie-x1.1",
"ernie-4.0-8k": "ernie-4.0-8k",
"ernie-3.5-8k": "ernie-3.5-8k",
}
# 默认 base URL千帆 v2 OpenAI 兼容接口)
WENXIN_DEFAULT_BASE_URL = "https://qianfan.baidubce.com/v2"
class WenxinProvider(OpenAICompatibleProvider):
"""百度文心 ERNIE Provider
通过千帆平台 v2 OpenAI 兼容接口调用文心模型
鉴权方式
- 方式1推荐直接使用 API Key OpenAI 兼容接口
- 方式2传统AK/SK 换取 access_token
使用方式
provider = WenxinProvider(api_key="your_api_key")
# 或使用 AK/SK
provider = WenxinProvider(api_key="", access_key="ak", secret_key="sk")
"""
def __init__(
self,
api_key: str = "",
access_key: str | None = None,
secret_key: str | None = None,
base_url: str = WENXIN_DEFAULT_BASE_URL,
default_model: str = "ernie-4.5-turbo-128k",
**kwargs: Any,
):
# If AK/SK provided, use token-based auth
self._access_key = access_key
self._secret_key = secret_key
self._access_token: str | None = None
self._token_expires_at: float = 0.0
# Resolve API key
effective_api_key = api_key
if not api_key and access_key and secret_key:
effective_api_key = "pending_token" # Will be resolved on first request
super().__init__(
api_key=effective_api_key,
base_url=base_url,
default_model=default_model,
**kwargs,
)
async def chat(self, request: LLMRequest) -> LLMResponse:
"""发送 chat 请求,处理文心特殊鉴权"""
# Resolve access token if using AK/SK
if self._access_key and self._secret_key and not self._api_key.startswith("pkf"):
await self._ensure_access_token()
if self._access_token:
self._api_key = self._access_token
# Map model name
request.model = WENXIN_MODEL_MAP.get(request.model, request.model)
return await super().chat(request)
async def _ensure_access_token(self) -> None:
"""确保 access_token 有效(缓存 29 天)"""
if self._access_token and time.time() < self._token_expires_at:
return
try:
import httpx
url = (
f"https://aip.baidubce.com/oauth/2.0/token?"
f"grant_type=client_credentials&client_id={self._access_key}"
f"&client_secret={self._secret_key}"
)
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(url)
data = response.json()
if "access_token" in data:
self._access_token = data["access_token"]
# Cache for 29 days (token valid for 30 days)
self._token_expires_at = time.time() + 29 * 86400
logger.info("Wenxin access token refreshed")
else:
logger.error(f"Failed to get Wenxin access token: {data}")
except Exception as e:
logger.error(f"Wenxin token refresh failed: {e}")

View File

@ -0,0 +1,71 @@
"""YuanbaoProvider - 腾讯混元/元宝 Provider
支持 Hunyuan 2.0/T1 系列模型
API腾讯云 OpenAI 兼容接口
鉴权Bearer API Key
"""
from __future__ import annotations
import logging
from typing import Any
from agentkit.llm.providers.openai import OpenAICompatibleProvider
from agentkit.llm.protocol import LLMRequest, LLMResponse
logger = logging.getLogger(__name__)
# 混元模型映射
YUANBAO_MODEL_MAP = {
"hunyuan-turbos-latest": "hunyuan-turbos-latest",
"hunyuan-2.0": "hunyuan-2.0",
"hunyuan-t1": "hunyuan-t1",
"hunyuan-vision-1.5": "hunyuan-vision-1.5",
}
# 腾讯混元 API base URL
YUANBAO_DEFAULT_BASE_URL = "https://api.hunyuan.cloud.tencent.com/v1"
class YuanbaoProvider(OpenAICompatibleProvider):
"""腾讯混元/元宝 Provider
通过腾讯云 OpenAI 兼容接口调用混元模型
使用方式
provider = YuanbaoProvider(
api_key="your_hunyuan_api_key",
default_model="hunyuan-turbos-latest",
)
特殊参数
- enable_enhancement: 增强模式通过 LLMRequest._extra 传递
"""
def __init__(
self,
api_key: str,
base_url: str = YUANBAO_DEFAULT_BASE_URL,
default_model: str = "hunyuan-turbos-latest",
enable_enhancement: bool = False,
**kwargs: Any,
):
self._enable_enhancement = enable_enhancement
super().__init__(
api_key=api_key,
base_url=base_url,
default_model=default_model,
**kwargs,
)
async def chat(self, request: LLMRequest) -> LLMResponse:
"""发送 chat 请求,处理混元模型映射和增强模式"""
request.model = YUANBAO_MODEL_MAP.get(request.model, request.model)
# Add enhancement parameter if enabled
if self._enable_enhancement:
if not hasattr(request, "_extra") or request._extra is None:
request._extra = {}
request._extra["enable_enhancement"] = True
return await super().chat(request)

View File

@ -0,0 +1,120 @@
"""Tests for Chinese LLM Providers (Wenxin, Doubao, Yuanbao)"""
import pytest
from agentkit.llm.providers.wenxin import WenxinProvider, WENXIN_MODEL_MAP
from agentkit.llm.providers.doubao import DoubaoProvider, DOUBAO_MODEL_MAP
from agentkit.llm.providers.yuanbao import YuanbaoProvider, YUANBAO_MODEL_MAP
from agentkit.llm.protocol import LLMRequest
class TestWenxinProvider:
"""WenxinProvider unit tests"""
def test_init_with_api_key(self):
provider = WenxinProvider(api_key="test_key")
assert provider._api_key == "test_key"
assert provider._default_model == "ernie-4.5-turbo-128k"
def test_init_with_ak_sk(self):
provider = WenxinProvider(
api_key="",
access_key="test_ak",
secret_key="test_sk",
)
assert provider._access_key == "test_ak"
assert provider._secret_key == "test_sk"
def test_model_mapping(self):
assert "ernie-4.5-turbo-128k" in WENXIN_MODEL_MAP
assert "ernie-5.0" in WENXIN_MODEL_MAP
assert "ernie-x1.1" in WENXIN_MODEL_MAP
def test_default_base_url(self):
from agentkit.llm.providers.wenxin import WENXIN_DEFAULT_BASE_URL
assert "qianfan.baidubce.com" in WENXIN_DEFAULT_BASE_URL
def test_custom_base_url(self):
provider = WenxinProvider(api_key="test", base_url="https://custom.api.com/v2")
assert "custom.api.com" in provider._base_url
class TestDoubaoProvider:
"""DoubaoProvider unit tests"""
def test_init(self):
provider = DoubaoProvider(api_key="test_key")
assert provider._api_key == "test_key"
assert provider._default_model == "doubao-pro-32k"
def test_model_mapping(self):
assert "doubao-pro-32k" in DOUBAO_MODEL_MAP
assert "doubao-lite-32k" in DOUBAO_MODEL_MAP
assert "doubao-vision" in DOUBAO_MODEL_MAP
def test_default_base_url(self):
from agentkit.llm.providers.doubao import DOUBAO_DEFAULT_BASE_URL
assert "ark.cn-beijing.volces.com" in DOUBAO_DEFAULT_BASE_URL
def test_custom_model(self):
provider = DoubaoProvider(
api_key="test",
default_model="doubao-lite-32k",
)
assert provider._default_model == "doubao-lite-32k"
class TestYuanbaoProvider:
"""YuanbaoProvider unit tests"""
def test_init(self):
provider = YuanbaoProvider(api_key="test_key")
assert provider._api_key == "test_key"
assert provider._default_model == "hunyuan-turbos-latest"
def test_init_with_enhancement(self):
provider = YuanbaoProvider(api_key="test", enable_enhancement=True)
assert provider._enable_enhancement is True
def test_model_mapping(self):
assert "hunyuan-turbos-latest" in YUANBAO_MODEL_MAP
assert "hunyuan-2.0" in YUANBAO_MODEL_MAP
assert "hunyuan-t1" in YUANBAO_MODEL_MAP
def test_default_base_url(self):
from agentkit.llm.providers.yuanbao import YUANBAO_DEFAULT_BASE_URL
assert "hunyuan.cloud.tencent.com" in YUANBAO_DEFAULT_BASE_URL
def test_enhancement_disabled_by_default(self):
provider = YuanbaoProvider(api_key="test")
assert provider._enable_enhancement is False
class TestProviderImports:
"""Test that all providers are importable from the package"""
def test_import_all_providers(self):
from agentkit.llm.providers import (
AnthropicProvider,
DoubaoProvider,
GeminiProvider,
OpenAICompatibleProvider,
WenxinProvider,
YuanbaoProvider,
)
assert AnthropicProvider is not None
assert DoubaoProvider is not None
assert GeminiProvider is not None
assert OpenAICompatibleProvider is not None
assert WenxinProvider is not None
assert YuanbaoProvider is not None
def test_inheritance(self):
"""All providers should inherit from OpenAICompatibleProvider or LLMProvider"""
from agentkit.llm.providers.openai import OpenAICompatibleProvider
from agentkit.llm.protocol import LLMProvider
assert issubclass(WenxinProvider, OpenAICompatibleProvider)
assert issubclass(DoubaoProvider, OpenAICompatibleProvider)
assert issubclass(YuanbaoProvider, OpenAICompatibleProvider)
assert issubclass(OpenAICompatibleProvider, LLMProvider)