feat(llm): U8 Chinese LLM providers - Wenxin, Doubao, Yuanbao
- WenxinProvider: Baidu ERNIE via Qianfan v2 OpenAI-compatible API, AK/SK token auth - DoubaoProvider: ByteDance Doubao via Volcengine Ark API - YuanbaoProvider: Tencent Hunyuan via OpenAI-compatible API with enhancement mode - All inherit from OpenAICompatibleProvider for retry/circuit breaker support - 16 tests passing
This commit is contained in:
parent
34e083abde
commit
9753a08ac8
|
|
@ -1,15 +1,21 @@
|
|||
"""LLM Providers"""
|
||||
|
||||
from agentkit.llm.providers.anthropic import AnthropicProvider
|
||||
from agentkit.llm.providers.doubao import DoubaoProvider
|
||||
from agentkit.llm.providers.gemini import GeminiProvider
|
||||
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
||||
from agentkit.llm.providers.tracker import UsageRecord, UsageSummary, UsageTracker
|
||||
from agentkit.llm.providers.wenxin import WenxinProvider
|
||||
from agentkit.llm.providers.yuanbao import YuanbaoProvider
|
||||
|
||||
__all__ = [
|
||||
"AnthropicProvider",
|
||||
"DoubaoProvider",
|
||||
"GeminiProvider",
|
||||
"OpenAICompatibleProvider",
|
||||
"UsageRecord",
|
||||
"UsageSummary",
|
||||
"UsageTracker",
|
||||
"WenxinProvider",
|
||||
"YuanbaoProvider",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,63 @@
|
|||
"""DoubaoProvider - 字节豆包 Provider
|
||||
|
||||
支持豆包 1.6 Pro/Lite 系列模型。
|
||||
API:火山引擎 OpenAI 兼容接口
|
||||
鉴权:Bearer API Key(火山引擎 IAM)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 豆包模型映射
|
||||
DOUBAO_MODEL_MAP = {
|
||||
"doubao-pro-32k": "doubao-pro-32k",
|
||||
"doubao-pro-128k": "doubao-pro-128k",
|
||||
"doubao-lite-32k": "doubao-lite-32k",
|
||||
"doubao-lite-128k": "doubao-lite-128k",
|
||||
"doubao-vision": "doubao-vision",
|
||||
}
|
||||
|
||||
# 火山引擎 API base URL
|
||||
DOUBAO_DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"
|
||||
|
||||
|
||||
class DoubaoProvider(OpenAICompatibleProvider):
|
||||
"""字节豆包 Provider
|
||||
|
||||
通过火山引擎 OpenAI 兼容接口调用豆包模型。
|
||||
|
||||
使用方式:
|
||||
provider = DoubaoProvider(
|
||||
api_key="your_ark_api_key",
|
||||
# 可选:指定推理接入点 ID 作为 default_model
|
||||
default_model="doubao-pro-32k",
|
||||
)
|
||||
|
||||
注意:火山引擎需要在控制台创建"推理接入点"获取 Service ID,
|
||||
也可以直接使用模型名称作为 endpoint_id。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str = DOUBAO_DEFAULT_BASE_URL,
|
||||
default_model: str = "doubao-pro-32k",
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
default_model=default_model,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def chat(self, request):
|
||||
"""发送 chat 请求,处理豆包模型映射"""
|
||||
request.model = DOUBAO_MODEL_MAP.get(request.model, request.model)
|
||||
return await super().chat(request)
|
||||
|
|
@ -0,0 +1,114 @@
|
|||
"""WenxinProvider - 百度文心 ERNIE Provider
|
||||
|
||||
支持 ERNIE 4.5/5.0 系列模型。
|
||||
鉴权:AK/SK → access_token(缓存 29 天)
|
||||
API:百度千帆平台 OpenAI 兼容接口
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
||||
from agentkit.llm.protocol import LLMRequest, LLMResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 文心模型到端点的映射
|
||||
WENXIN_MODEL_MAP = {
|
||||
"ernie-4.5-turbo-128k": "ernie-4.5-turbo-128k",
|
||||
"ernie-5.0": "ernie-5.0",
|
||||
"ernie-x1.1": "ernie-x1.1",
|
||||
"ernie-4.0-8k": "ernie-4.0-8k",
|
||||
"ernie-3.5-8k": "ernie-3.5-8k",
|
||||
}
|
||||
|
||||
# 默认 base URL(千帆 v2 OpenAI 兼容接口)
|
||||
WENXIN_DEFAULT_BASE_URL = "https://qianfan.baidubce.com/v2"
|
||||
|
||||
|
||||
class WenxinProvider(OpenAICompatibleProvider):
|
||||
"""百度文心 ERNIE Provider
|
||||
|
||||
通过千帆平台 v2 OpenAI 兼容接口调用文心模型。
|
||||
|
||||
鉴权方式:
|
||||
- 方式1(推荐):直接使用 API Key,走 OpenAI 兼容接口
|
||||
- 方式2(传统):AK/SK 换取 access_token
|
||||
|
||||
使用方式:
|
||||
provider = WenxinProvider(api_key="your_api_key")
|
||||
# 或使用 AK/SK
|
||||
provider = WenxinProvider(api_key="", access_key="ak", secret_key="sk")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str = "",
|
||||
access_key: str | None = None,
|
||||
secret_key: str | None = None,
|
||||
base_url: str = WENXIN_DEFAULT_BASE_URL,
|
||||
default_model: str = "ernie-4.5-turbo-128k",
|
||||
**kwargs: Any,
|
||||
):
|
||||
# If AK/SK provided, use token-based auth
|
||||
self._access_key = access_key
|
||||
self._secret_key = secret_key
|
||||
self._access_token: str | None = None
|
||||
self._token_expires_at: float = 0.0
|
||||
|
||||
# Resolve API key
|
||||
effective_api_key = api_key
|
||||
if not api_key and access_key and secret_key:
|
||||
effective_api_key = "pending_token" # Will be resolved on first request
|
||||
|
||||
super().__init__(
|
||||
api_key=effective_api_key,
|
||||
base_url=base_url,
|
||||
default_model=default_model,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def chat(self, request: LLMRequest) -> LLMResponse:
|
||||
"""发送 chat 请求,处理文心特殊鉴权"""
|
||||
# Resolve access token if using AK/SK
|
||||
if self._access_key and self._secret_key and not self._api_key.startswith("pkf"):
|
||||
await self._ensure_access_token()
|
||||
if self._access_token:
|
||||
self._api_key = self._access_token
|
||||
|
||||
# Map model name
|
||||
request.model = WENXIN_MODEL_MAP.get(request.model, request.model)
|
||||
|
||||
return await super().chat(request)
|
||||
|
||||
async def _ensure_access_token(self) -> None:
|
||||
"""确保 access_token 有效(缓存 29 天)"""
|
||||
if self._access_token and time.time() < self._token_expires_at:
|
||||
return
|
||||
|
||||
try:
|
||||
import httpx
|
||||
|
||||
url = (
|
||||
f"https://aip.baidubce.com/oauth/2.0/token?"
|
||||
f"grant_type=client_credentials&client_id={self._access_key}"
|
||||
f"&client_secret={self._secret_key}"
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(url)
|
||||
data = response.json()
|
||||
|
||||
if "access_token" in data:
|
||||
self._access_token = data["access_token"]
|
||||
# Cache for 29 days (token valid for 30 days)
|
||||
self._token_expires_at = time.time() + 29 * 86400
|
||||
logger.info("Wenxin access token refreshed")
|
||||
else:
|
||||
logger.error(f"Failed to get Wenxin access token: {data}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Wenxin token refresh failed: {e}")
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
"""YuanbaoProvider - 腾讯混元/元宝 Provider
|
||||
|
||||
支持 Hunyuan 2.0/T1 系列模型。
|
||||
API:腾讯云 OpenAI 兼容接口
|
||||
鉴权:Bearer API Key
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
||||
from agentkit.llm.protocol import LLMRequest, LLMResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 混元模型映射
|
||||
YUANBAO_MODEL_MAP = {
|
||||
"hunyuan-turbos-latest": "hunyuan-turbos-latest",
|
||||
"hunyuan-2.0": "hunyuan-2.0",
|
||||
"hunyuan-t1": "hunyuan-t1",
|
||||
"hunyuan-vision-1.5": "hunyuan-vision-1.5",
|
||||
}
|
||||
|
||||
# 腾讯混元 API base URL
|
||||
YUANBAO_DEFAULT_BASE_URL = "https://api.hunyuan.cloud.tencent.com/v1"
|
||||
|
||||
|
||||
class YuanbaoProvider(OpenAICompatibleProvider):
|
||||
"""腾讯混元/元宝 Provider
|
||||
|
||||
通过腾讯云 OpenAI 兼容接口调用混元模型。
|
||||
|
||||
使用方式:
|
||||
provider = YuanbaoProvider(
|
||||
api_key="your_hunyuan_api_key",
|
||||
default_model="hunyuan-turbos-latest",
|
||||
)
|
||||
|
||||
特殊参数:
|
||||
- enable_enhancement: 增强模式(通过 LLMRequest._extra 传递)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str = YUANBAO_DEFAULT_BASE_URL,
|
||||
default_model: str = "hunyuan-turbos-latest",
|
||||
enable_enhancement: bool = False,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self._enable_enhancement = enable_enhancement
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
default_model=default_model,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def chat(self, request: LLMRequest) -> LLMResponse:
|
||||
"""发送 chat 请求,处理混元模型映射和增强模式"""
|
||||
request.model = YUANBAO_MODEL_MAP.get(request.model, request.model)
|
||||
|
||||
# Add enhancement parameter if enabled
|
||||
if self._enable_enhancement:
|
||||
if not hasattr(request, "_extra") or request._extra is None:
|
||||
request._extra = {}
|
||||
request._extra["enable_enhancement"] = True
|
||||
|
||||
return await super().chat(request)
|
||||
|
|
@ -0,0 +1,120 @@
|
|||
"""Tests for Chinese LLM Providers (Wenxin, Doubao, Yuanbao)"""
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.llm.providers.wenxin import WenxinProvider, WENXIN_MODEL_MAP
|
||||
from agentkit.llm.providers.doubao import DoubaoProvider, DOUBAO_MODEL_MAP
|
||||
from agentkit.llm.providers.yuanbao import YuanbaoProvider, YUANBAO_MODEL_MAP
|
||||
from agentkit.llm.protocol import LLMRequest
|
||||
|
||||
|
||||
class TestWenxinProvider:
|
||||
"""WenxinProvider unit tests"""
|
||||
|
||||
def test_init_with_api_key(self):
|
||||
provider = WenxinProvider(api_key="test_key")
|
||||
assert provider._api_key == "test_key"
|
||||
assert provider._default_model == "ernie-4.5-turbo-128k"
|
||||
|
||||
def test_init_with_ak_sk(self):
|
||||
provider = WenxinProvider(
|
||||
api_key="",
|
||||
access_key="test_ak",
|
||||
secret_key="test_sk",
|
||||
)
|
||||
assert provider._access_key == "test_ak"
|
||||
assert provider._secret_key == "test_sk"
|
||||
|
||||
def test_model_mapping(self):
|
||||
assert "ernie-4.5-turbo-128k" in WENXIN_MODEL_MAP
|
||||
assert "ernie-5.0" in WENXIN_MODEL_MAP
|
||||
assert "ernie-x1.1" in WENXIN_MODEL_MAP
|
||||
|
||||
def test_default_base_url(self):
|
||||
from agentkit.llm.providers.wenxin import WENXIN_DEFAULT_BASE_URL
|
||||
assert "qianfan.baidubce.com" in WENXIN_DEFAULT_BASE_URL
|
||||
|
||||
def test_custom_base_url(self):
|
||||
provider = WenxinProvider(api_key="test", base_url="https://custom.api.com/v2")
|
||||
assert "custom.api.com" in provider._base_url
|
||||
|
||||
|
||||
class TestDoubaoProvider:
|
||||
"""DoubaoProvider unit tests"""
|
||||
|
||||
def test_init(self):
|
||||
provider = DoubaoProvider(api_key="test_key")
|
||||
assert provider._api_key == "test_key"
|
||||
assert provider._default_model == "doubao-pro-32k"
|
||||
|
||||
def test_model_mapping(self):
|
||||
assert "doubao-pro-32k" in DOUBAO_MODEL_MAP
|
||||
assert "doubao-lite-32k" in DOUBAO_MODEL_MAP
|
||||
assert "doubao-vision" in DOUBAO_MODEL_MAP
|
||||
|
||||
def test_default_base_url(self):
|
||||
from agentkit.llm.providers.doubao import DOUBAO_DEFAULT_BASE_URL
|
||||
assert "ark.cn-beijing.volces.com" in DOUBAO_DEFAULT_BASE_URL
|
||||
|
||||
def test_custom_model(self):
|
||||
provider = DoubaoProvider(
|
||||
api_key="test",
|
||||
default_model="doubao-lite-32k",
|
||||
)
|
||||
assert provider._default_model == "doubao-lite-32k"
|
||||
|
||||
|
||||
class TestYuanbaoProvider:
|
||||
"""YuanbaoProvider unit tests"""
|
||||
|
||||
def test_init(self):
|
||||
provider = YuanbaoProvider(api_key="test_key")
|
||||
assert provider._api_key == "test_key"
|
||||
assert provider._default_model == "hunyuan-turbos-latest"
|
||||
|
||||
def test_init_with_enhancement(self):
|
||||
provider = YuanbaoProvider(api_key="test", enable_enhancement=True)
|
||||
assert provider._enable_enhancement is True
|
||||
|
||||
def test_model_mapping(self):
|
||||
assert "hunyuan-turbos-latest" in YUANBAO_MODEL_MAP
|
||||
assert "hunyuan-2.0" in YUANBAO_MODEL_MAP
|
||||
assert "hunyuan-t1" in YUANBAO_MODEL_MAP
|
||||
|
||||
def test_default_base_url(self):
|
||||
from agentkit.llm.providers.yuanbao import YUANBAO_DEFAULT_BASE_URL
|
||||
assert "hunyuan.cloud.tencent.com" in YUANBAO_DEFAULT_BASE_URL
|
||||
|
||||
def test_enhancement_disabled_by_default(self):
|
||||
provider = YuanbaoProvider(api_key="test")
|
||||
assert provider._enable_enhancement is False
|
||||
|
||||
|
||||
class TestProviderImports:
|
||||
"""Test that all providers are importable from the package"""
|
||||
|
||||
def test_import_all_providers(self):
|
||||
from agentkit.llm.providers import (
|
||||
AnthropicProvider,
|
||||
DoubaoProvider,
|
||||
GeminiProvider,
|
||||
OpenAICompatibleProvider,
|
||||
WenxinProvider,
|
||||
YuanbaoProvider,
|
||||
)
|
||||
assert AnthropicProvider is not None
|
||||
assert DoubaoProvider is not None
|
||||
assert GeminiProvider is not None
|
||||
assert OpenAICompatibleProvider is not None
|
||||
assert WenxinProvider is not None
|
||||
assert YuanbaoProvider is not None
|
||||
|
||||
def test_inheritance(self):
|
||||
"""All providers should inherit from OpenAICompatibleProvider or LLMProvider"""
|
||||
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
||||
from agentkit.llm.protocol import LLMProvider
|
||||
|
||||
assert issubclass(WenxinProvider, OpenAICompatibleProvider)
|
||||
assert issubclass(DoubaoProvider, OpenAICompatibleProvider)
|
||||
assert issubclass(YuanbaoProvider, OpenAICompatibleProvider)
|
||||
assert issubclass(OpenAICompatibleProvider, LLMProvider)
|
||||
Loading…
Reference in New Issue