feat(llm): U8 Chinese LLM providers - Wenxin, Doubao, Yuanbao
- WenxinProvider: Baidu ERNIE via Qianfan v2 OpenAI-compatible API, AK/SK token auth - DoubaoProvider: ByteDance Doubao via Volcengine Ark API - YuanbaoProvider: Tencent Hunyuan via OpenAI-compatible API with enhancement mode - All inherit from OpenAICompatibleProvider for retry/circuit breaker support - 16 tests passing
This commit is contained in:
parent
34e083abde
commit
9753a08ac8
|
|
@ -1,15 +1,21 @@
|
||||||
"""LLM Providers"""
|
"""LLM Providers"""
|
||||||
|
|
||||||
from agentkit.llm.providers.anthropic import AnthropicProvider
|
from agentkit.llm.providers.anthropic import AnthropicProvider
|
||||||
|
from agentkit.llm.providers.doubao import DoubaoProvider
|
||||||
from agentkit.llm.providers.gemini import GeminiProvider
|
from agentkit.llm.providers.gemini import GeminiProvider
|
||||||
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
||||||
from agentkit.llm.providers.tracker import UsageRecord, UsageSummary, UsageTracker
|
from agentkit.llm.providers.tracker import UsageRecord, UsageSummary, UsageTracker
|
||||||
|
from agentkit.llm.providers.wenxin import WenxinProvider
|
||||||
|
from agentkit.llm.providers.yuanbao import YuanbaoProvider
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AnthropicProvider",
|
"AnthropicProvider",
|
||||||
|
"DoubaoProvider",
|
||||||
"GeminiProvider",
|
"GeminiProvider",
|
||||||
"OpenAICompatibleProvider",
|
"OpenAICompatibleProvider",
|
||||||
"UsageRecord",
|
"UsageRecord",
|
||||||
"UsageSummary",
|
"UsageSummary",
|
||||||
"UsageTracker",
|
"UsageTracker",
|
||||||
|
"WenxinProvider",
|
||||||
|
"YuanbaoProvider",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,63 @@
|
||||||
|
"""DoubaoProvider - 字节豆包 Provider
|
||||||
|
|
||||||
|
支持豆包 1.6 Pro/Lite 系列模型。
|
||||||
|
API:火山引擎 OpenAI 兼容接口
|
||||||
|
鉴权:Bearer API Key(火山引擎 IAM)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 豆包模型映射
|
||||||
|
DOUBAO_MODEL_MAP = {
|
||||||
|
"doubao-pro-32k": "doubao-pro-32k",
|
||||||
|
"doubao-pro-128k": "doubao-pro-128k",
|
||||||
|
"doubao-lite-32k": "doubao-lite-32k",
|
||||||
|
"doubao-lite-128k": "doubao-lite-128k",
|
||||||
|
"doubao-vision": "doubao-vision",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 火山引擎 API base URL
|
||||||
|
DOUBAO_DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"
|
||||||
|
|
||||||
|
|
||||||
|
class DoubaoProvider(OpenAICompatibleProvider):
|
||||||
|
"""字节豆包 Provider
|
||||||
|
|
||||||
|
通过火山引擎 OpenAI 兼容接口调用豆包模型。
|
||||||
|
|
||||||
|
使用方式:
|
||||||
|
provider = DoubaoProvider(
|
||||||
|
api_key="your_ark_api_key",
|
||||||
|
# 可选:指定推理接入点 ID 作为 default_model
|
||||||
|
default_model="doubao-pro-32k",
|
||||||
|
)
|
||||||
|
|
||||||
|
注意:火山引擎需要在控制台创建"推理接入点"获取 Service ID,
|
||||||
|
也可以直接使用模型名称作为 endpoint_id。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str,
|
||||||
|
base_url: str = DOUBAO_DEFAULT_BASE_URL,
|
||||||
|
default_model: str = "doubao-pro-32k",
|
||||||
|
**kwargs: Any,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
default_model=default_model,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat(self, request):
|
||||||
|
"""发送 chat 请求,处理豆包模型映射"""
|
||||||
|
request.model = DOUBAO_MODEL_MAP.get(request.model, request.model)
|
||||||
|
return await super().chat(request)
|
||||||
|
|
@ -0,0 +1,114 @@
|
||||||
|
"""WenxinProvider - 百度文心 ERNIE Provider
|
||||||
|
|
||||||
|
支持 ERNIE 4.5/5.0 系列模型。
|
||||||
|
鉴权:AK/SK → access_token(缓存 29 天)
|
||||||
|
API:百度千帆平台 OpenAI 兼容接口
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
||||||
|
from agentkit.llm.protocol import LLMRequest, LLMResponse
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 文心模型到端点的映射
|
||||||
|
WENXIN_MODEL_MAP = {
|
||||||
|
"ernie-4.5-turbo-128k": "ernie-4.5-turbo-128k",
|
||||||
|
"ernie-5.0": "ernie-5.0",
|
||||||
|
"ernie-x1.1": "ernie-x1.1",
|
||||||
|
"ernie-4.0-8k": "ernie-4.0-8k",
|
||||||
|
"ernie-3.5-8k": "ernie-3.5-8k",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 默认 base URL(千帆 v2 OpenAI 兼容接口)
|
||||||
|
WENXIN_DEFAULT_BASE_URL = "https://qianfan.baidubce.com/v2"
|
||||||
|
|
||||||
|
|
||||||
|
class WenxinProvider(OpenAICompatibleProvider):
|
||||||
|
"""百度文心 ERNIE Provider
|
||||||
|
|
||||||
|
通过千帆平台 v2 OpenAI 兼容接口调用文心模型。
|
||||||
|
|
||||||
|
鉴权方式:
|
||||||
|
- 方式1(推荐):直接使用 API Key,走 OpenAI 兼容接口
|
||||||
|
- 方式2(传统):AK/SK 换取 access_token
|
||||||
|
|
||||||
|
使用方式:
|
||||||
|
provider = WenxinProvider(api_key="your_api_key")
|
||||||
|
# 或使用 AK/SK
|
||||||
|
provider = WenxinProvider(api_key="", access_key="ak", secret_key="sk")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str = "",
|
||||||
|
access_key: str | None = None,
|
||||||
|
secret_key: str | None = None,
|
||||||
|
base_url: str = WENXIN_DEFAULT_BASE_URL,
|
||||||
|
default_model: str = "ernie-4.5-turbo-128k",
|
||||||
|
**kwargs: Any,
|
||||||
|
):
|
||||||
|
# If AK/SK provided, use token-based auth
|
||||||
|
self._access_key = access_key
|
||||||
|
self._secret_key = secret_key
|
||||||
|
self._access_token: str | None = None
|
||||||
|
self._token_expires_at: float = 0.0
|
||||||
|
|
||||||
|
# Resolve API key
|
||||||
|
effective_api_key = api_key
|
||||||
|
if not api_key and access_key and secret_key:
|
||||||
|
effective_api_key = "pending_token" # Will be resolved on first request
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
api_key=effective_api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
default_model=default_model,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat(self, request: LLMRequest) -> LLMResponse:
|
||||||
|
"""发送 chat 请求,处理文心特殊鉴权"""
|
||||||
|
# Resolve access token if using AK/SK
|
||||||
|
if self._access_key and self._secret_key and not self._api_key.startswith("pkf"):
|
||||||
|
await self._ensure_access_token()
|
||||||
|
if self._access_token:
|
||||||
|
self._api_key = self._access_token
|
||||||
|
|
||||||
|
# Map model name
|
||||||
|
request.model = WENXIN_MODEL_MAP.get(request.model, request.model)
|
||||||
|
|
||||||
|
return await super().chat(request)
|
||||||
|
|
||||||
|
async def _ensure_access_token(self) -> None:
|
||||||
|
"""确保 access_token 有效(缓存 29 天)"""
|
||||||
|
if self._access_token and time.time() < self._token_expires_at:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
url = (
|
||||||
|
f"https://aip.baidubce.com/oauth/2.0/token?"
|
||||||
|
f"grant_type=client_credentials&client_id={self._access_key}"
|
||||||
|
f"&client_secret={self._secret_key}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
response = await client.post(url)
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
if "access_token" in data:
|
||||||
|
self._access_token = data["access_token"]
|
||||||
|
# Cache for 29 days (token valid for 30 days)
|
||||||
|
self._token_expires_at = time.time() + 29 * 86400
|
||||||
|
logger.info("Wenxin access token refreshed")
|
||||||
|
else:
|
||||||
|
logger.error(f"Failed to get Wenxin access token: {data}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Wenxin token refresh failed: {e}")
|
||||||
|
|
@ -0,0 +1,71 @@
|
||||||
|
"""YuanbaoProvider - 腾讯混元/元宝 Provider
|
||||||
|
|
||||||
|
支持 Hunyuan 2.0/T1 系列模型。
|
||||||
|
API:腾讯云 OpenAI 兼容接口
|
||||||
|
鉴权:Bearer API Key
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
||||||
|
from agentkit.llm.protocol import LLMRequest, LLMResponse
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 混元模型映射
|
||||||
|
YUANBAO_MODEL_MAP = {
|
||||||
|
"hunyuan-turbos-latest": "hunyuan-turbos-latest",
|
||||||
|
"hunyuan-2.0": "hunyuan-2.0",
|
||||||
|
"hunyuan-t1": "hunyuan-t1",
|
||||||
|
"hunyuan-vision-1.5": "hunyuan-vision-1.5",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 腾讯混元 API base URL
|
||||||
|
YUANBAO_DEFAULT_BASE_URL = "https://api.hunyuan.cloud.tencent.com/v1"
|
||||||
|
|
||||||
|
|
||||||
|
class YuanbaoProvider(OpenAICompatibleProvider):
|
||||||
|
"""腾讯混元/元宝 Provider
|
||||||
|
|
||||||
|
通过腾讯云 OpenAI 兼容接口调用混元模型。
|
||||||
|
|
||||||
|
使用方式:
|
||||||
|
provider = YuanbaoProvider(
|
||||||
|
api_key="your_hunyuan_api_key",
|
||||||
|
default_model="hunyuan-turbos-latest",
|
||||||
|
)
|
||||||
|
|
||||||
|
特殊参数:
|
||||||
|
- enable_enhancement: 增强模式(通过 LLMRequest._extra 传递)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str,
|
||||||
|
base_url: str = YUANBAO_DEFAULT_BASE_URL,
|
||||||
|
default_model: str = "hunyuan-turbos-latest",
|
||||||
|
enable_enhancement: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
):
|
||||||
|
self._enable_enhancement = enable_enhancement
|
||||||
|
super().__init__(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
default_model=default_model,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat(self, request: LLMRequest) -> LLMResponse:
|
||||||
|
"""发送 chat 请求,处理混元模型映射和增强模式"""
|
||||||
|
request.model = YUANBAO_MODEL_MAP.get(request.model, request.model)
|
||||||
|
|
||||||
|
# Add enhancement parameter if enabled
|
||||||
|
if self._enable_enhancement:
|
||||||
|
if not hasattr(request, "_extra") or request._extra is None:
|
||||||
|
request._extra = {}
|
||||||
|
request._extra["enable_enhancement"] = True
|
||||||
|
|
||||||
|
return await super().chat(request)
|
||||||
|
|
@ -0,0 +1,120 @@
|
||||||
|
"""Tests for Chinese LLM Providers (Wenxin, Doubao, Yuanbao)"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.llm.providers.wenxin import WenxinProvider, WENXIN_MODEL_MAP
|
||||||
|
from agentkit.llm.providers.doubao import DoubaoProvider, DOUBAO_MODEL_MAP
|
||||||
|
from agentkit.llm.providers.yuanbao import YuanbaoProvider, YUANBAO_MODEL_MAP
|
||||||
|
from agentkit.llm.protocol import LLMRequest
|
||||||
|
|
||||||
|
|
||||||
|
class TestWenxinProvider:
|
||||||
|
"""WenxinProvider unit tests"""
|
||||||
|
|
||||||
|
def test_init_with_api_key(self):
|
||||||
|
provider = WenxinProvider(api_key="test_key")
|
||||||
|
assert provider._api_key == "test_key"
|
||||||
|
assert provider._default_model == "ernie-4.5-turbo-128k"
|
||||||
|
|
||||||
|
def test_init_with_ak_sk(self):
|
||||||
|
provider = WenxinProvider(
|
||||||
|
api_key="",
|
||||||
|
access_key="test_ak",
|
||||||
|
secret_key="test_sk",
|
||||||
|
)
|
||||||
|
assert provider._access_key == "test_ak"
|
||||||
|
assert provider._secret_key == "test_sk"
|
||||||
|
|
||||||
|
def test_model_mapping(self):
|
||||||
|
assert "ernie-4.5-turbo-128k" in WENXIN_MODEL_MAP
|
||||||
|
assert "ernie-5.0" in WENXIN_MODEL_MAP
|
||||||
|
assert "ernie-x1.1" in WENXIN_MODEL_MAP
|
||||||
|
|
||||||
|
def test_default_base_url(self):
|
||||||
|
from agentkit.llm.providers.wenxin import WENXIN_DEFAULT_BASE_URL
|
||||||
|
assert "qianfan.baidubce.com" in WENXIN_DEFAULT_BASE_URL
|
||||||
|
|
||||||
|
def test_custom_base_url(self):
|
||||||
|
provider = WenxinProvider(api_key="test", base_url="https://custom.api.com/v2")
|
||||||
|
assert "custom.api.com" in provider._base_url
|
||||||
|
|
||||||
|
|
||||||
|
class TestDoubaoProvider:
|
||||||
|
"""DoubaoProvider unit tests"""
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
provider = DoubaoProvider(api_key="test_key")
|
||||||
|
assert provider._api_key == "test_key"
|
||||||
|
assert provider._default_model == "doubao-pro-32k"
|
||||||
|
|
||||||
|
def test_model_mapping(self):
|
||||||
|
assert "doubao-pro-32k" in DOUBAO_MODEL_MAP
|
||||||
|
assert "doubao-lite-32k" in DOUBAO_MODEL_MAP
|
||||||
|
assert "doubao-vision" in DOUBAO_MODEL_MAP
|
||||||
|
|
||||||
|
def test_default_base_url(self):
|
||||||
|
from agentkit.llm.providers.doubao import DOUBAO_DEFAULT_BASE_URL
|
||||||
|
assert "ark.cn-beijing.volces.com" in DOUBAO_DEFAULT_BASE_URL
|
||||||
|
|
||||||
|
def test_custom_model(self):
|
||||||
|
provider = DoubaoProvider(
|
||||||
|
api_key="test",
|
||||||
|
default_model="doubao-lite-32k",
|
||||||
|
)
|
||||||
|
assert provider._default_model == "doubao-lite-32k"
|
||||||
|
|
||||||
|
|
||||||
|
class TestYuanbaoProvider:
|
||||||
|
"""YuanbaoProvider unit tests"""
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
provider = YuanbaoProvider(api_key="test_key")
|
||||||
|
assert provider._api_key == "test_key"
|
||||||
|
assert provider._default_model == "hunyuan-turbos-latest"
|
||||||
|
|
||||||
|
def test_init_with_enhancement(self):
|
||||||
|
provider = YuanbaoProvider(api_key="test", enable_enhancement=True)
|
||||||
|
assert provider._enable_enhancement is True
|
||||||
|
|
||||||
|
def test_model_mapping(self):
|
||||||
|
assert "hunyuan-turbos-latest" in YUANBAO_MODEL_MAP
|
||||||
|
assert "hunyuan-2.0" in YUANBAO_MODEL_MAP
|
||||||
|
assert "hunyuan-t1" in YUANBAO_MODEL_MAP
|
||||||
|
|
||||||
|
def test_default_base_url(self):
|
||||||
|
from agentkit.llm.providers.yuanbao import YUANBAO_DEFAULT_BASE_URL
|
||||||
|
assert "hunyuan.cloud.tencent.com" in YUANBAO_DEFAULT_BASE_URL
|
||||||
|
|
||||||
|
def test_enhancement_disabled_by_default(self):
|
||||||
|
provider = YuanbaoProvider(api_key="test")
|
||||||
|
assert provider._enable_enhancement is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestProviderImports:
|
||||||
|
"""Test that all providers are importable from the package"""
|
||||||
|
|
||||||
|
def test_import_all_providers(self):
|
||||||
|
from agentkit.llm.providers import (
|
||||||
|
AnthropicProvider,
|
||||||
|
DoubaoProvider,
|
||||||
|
GeminiProvider,
|
||||||
|
OpenAICompatibleProvider,
|
||||||
|
WenxinProvider,
|
||||||
|
YuanbaoProvider,
|
||||||
|
)
|
||||||
|
assert AnthropicProvider is not None
|
||||||
|
assert DoubaoProvider is not None
|
||||||
|
assert GeminiProvider is not None
|
||||||
|
assert OpenAICompatibleProvider is not None
|
||||||
|
assert WenxinProvider is not None
|
||||||
|
assert YuanbaoProvider is not None
|
||||||
|
|
||||||
|
def test_inheritance(self):
|
||||||
|
"""All providers should inherit from OpenAICompatibleProvider or LLMProvider"""
|
||||||
|
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
||||||
|
from agentkit.llm.protocol import LLMProvider
|
||||||
|
|
||||||
|
assert issubclass(WenxinProvider, OpenAICompatibleProvider)
|
||||||
|
assert issubclass(DoubaoProvider, OpenAICompatibleProvider)
|
||||||
|
assert issubclass(YuanbaoProvider, OpenAICompatibleProvider)
|
||||||
|
assert issubclass(OpenAICompatibleProvider, LLMProvider)
|
||||||
Loading…
Reference in New Issue