diff --git a/src/agentkit/llm/providers/__init__.py b/src/agentkit/llm/providers/__init__.py index 66183cf..5a3ac74 100644 --- a/src/agentkit/llm/providers/__init__.py +++ b/src/agentkit/llm/providers/__init__.py @@ -1,15 +1,21 @@ """LLM Providers""" from agentkit.llm.providers.anthropic import AnthropicProvider +from agentkit.llm.providers.doubao import DoubaoProvider from agentkit.llm.providers.gemini import GeminiProvider from agentkit.llm.providers.openai import OpenAICompatibleProvider from agentkit.llm.providers.tracker import UsageRecord, UsageSummary, UsageTracker +from agentkit.llm.providers.wenxin import WenxinProvider +from agentkit.llm.providers.yuanbao import YuanbaoProvider __all__ = [ "AnthropicProvider", + "DoubaoProvider", "GeminiProvider", "OpenAICompatibleProvider", "UsageRecord", "UsageSummary", "UsageTracker", + "WenxinProvider", + "YuanbaoProvider", ] diff --git a/src/agentkit/llm/providers/doubao.py b/src/agentkit/llm/providers/doubao.py new file mode 100644 index 0000000..ebd7f9a --- /dev/null +++ b/src/agentkit/llm/providers/doubao.py @@ -0,0 +1,63 @@ +"""DoubaoProvider - 字节豆包 Provider + +支持豆包 1.6 Pro/Lite 系列模型。 +API:火山引擎 OpenAI 兼容接口 +鉴权:Bearer API Key(火山引擎 IAM) +""" + +from __future__ import annotations + +import logging +from typing import Any + +from agentkit.llm.providers.openai import OpenAICompatibleProvider + +logger = logging.getLogger(__name__) + +# 豆包模型映射 +DOUBAO_MODEL_MAP = { + "doubao-pro-32k": "doubao-pro-32k", + "doubao-pro-128k": "doubao-pro-128k", + "doubao-lite-32k": "doubao-lite-32k", + "doubao-lite-128k": "doubao-lite-128k", + "doubao-vision": "doubao-vision", +} + +# 火山引擎 API base URL +DOUBAO_DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3" + + +class DoubaoProvider(OpenAICompatibleProvider): + """字节豆包 Provider + + 通过火山引擎 OpenAI 兼容接口调用豆包模型。 + + 使用方式: + provider = DoubaoProvider( + api_key="your_ark_api_key", + # 可选:指定推理接入点 ID 作为 default_model + default_model="doubao-pro-32k", + ) + + 注意:火山引擎需要在控制台创建"推理接入点"获取 Service ID, + 也可以直接使用模型名称作为 endpoint_id。 + """ + + def __init__( + self, + api_key: str, + base_url: str = DOUBAO_DEFAULT_BASE_URL, + default_model: str = "doubao-pro-32k", + **kwargs: Any, + ): + super().__init__( + api_key=api_key, + base_url=base_url, + default_model=default_model, + **kwargs, + ) + + async def chat(self, request): + """发送 chat 请求,处理豆包模型映射""" + request.model = DOUBAO_MODEL_MAP.get(request.model, request.model) + return await super().chat(request) diff --git a/src/agentkit/llm/providers/wenxin.py b/src/agentkit/llm/providers/wenxin.py new file mode 100644 index 0000000..ee4e290 --- /dev/null +++ b/src/agentkit/llm/providers/wenxin.py @@ -0,0 +1,114 @@ +"""WenxinProvider - 百度文心 ERNIE Provider + +支持 ERNIE 4.5/5.0 系列模型。 +鉴权:AK/SK → access_token(缓存 29 天) +API:百度千帆平台 OpenAI 兼容接口 +""" + +from __future__ import annotations + +import logging +import time +from typing import Any + +from agentkit.llm.providers.openai import OpenAICompatibleProvider +from agentkit.llm.protocol import LLMRequest, LLMResponse + +logger = logging.getLogger(__name__) + +# 文心模型到端点的映射 +WENXIN_MODEL_MAP = { + "ernie-4.5-turbo-128k": "ernie-4.5-turbo-128k", + "ernie-5.0": "ernie-5.0", + "ernie-x1.1": "ernie-x1.1", + "ernie-4.0-8k": "ernie-4.0-8k", + "ernie-3.5-8k": "ernie-3.5-8k", +} + +# 默认 base URL(千帆 v2 OpenAI 兼容接口) +WENXIN_DEFAULT_BASE_URL = "https://qianfan.baidubce.com/v2" + + +class WenxinProvider(OpenAICompatibleProvider): + """百度文心 ERNIE Provider + + 通过千帆平台 v2 OpenAI 兼容接口调用文心模型。 + + 鉴权方式: + - 方式1(推荐):直接使用 API Key,走 OpenAI 兼容接口 + - 方式2(传统):AK/SK 换取 access_token + + 使用方式: + provider = WenxinProvider(api_key="your_api_key") + # 或使用 AK/SK + provider = WenxinProvider(api_key="", access_key="ak", secret_key="sk") + """ + + def __init__( + self, + api_key: str = "", + access_key: str | None = None, + secret_key: str | None = None, + base_url: str = WENXIN_DEFAULT_BASE_URL, + default_model: str = "ernie-4.5-turbo-128k", + **kwargs: Any, + ): + # If AK/SK provided, use token-based auth + self._access_key = access_key + self._secret_key = secret_key + self._access_token: str | None = None + self._token_expires_at: float = 0.0 + + # Resolve API key + effective_api_key = api_key + if not api_key and access_key and secret_key: + effective_api_key = "pending_token" # Will be resolved on first request + + super().__init__( + api_key=effective_api_key, + base_url=base_url, + default_model=default_model, + **kwargs, + ) + + async def chat(self, request: LLMRequest) -> LLMResponse: + """发送 chat 请求,处理文心特殊鉴权""" + # Resolve access token if using AK/SK + if self._access_key and self._secret_key and not self._api_key.startswith("pkf"): + await self._ensure_access_token() + if self._access_token: + self._api_key = self._access_token + + # Map model name + request.model = WENXIN_MODEL_MAP.get(request.model, request.model) + + return await super().chat(request) + + async def _ensure_access_token(self) -> None: + """确保 access_token 有效(缓存 29 天)""" + if self._access_token and time.time() < self._token_expires_at: + return + + try: + import httpx + + url = ( + f"https://aip.baidubce.com/oauth/2.0/token?" + f"grant_type=client_credentials&client_id={self._access_key}" + f"&client_secret={self._secret_key}" + ) + + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.post(url) + data = response.json() + + if "access_token" in data: + self._access_token = data["access_token"] + # Cache for 29 days (token valid for 30 days) + self._token_expires_at = time.time() + 29 * 86400 + logger.info("Wenxin access token refreshed") + else: + logger.error(f"Failed to get Wenxin access token: {data}") + + except Exception as e: + logger.error(f"Wenxin token refresh failed: {e}") diff --git a/src/agentkit/llm/providers/yuanbao.py b/src/agentkit/llm/providers/yuanbao.py new file mode 100644 index 0000000..a055c36 --- /dev/null +++ b/src/agentkit/llm/providers/yuanbao.py @@ -0,0 +1,71 @@ +"""YuanbaoProvider - 腾讯混元/元宝 Provider + +支持 Hunyuan 2.0/T1 系列模型。 +API:腾讯云 OpenAI 兼容接口 +鉴权:Bearer API Key +""" + +from __future__ import annotations + +import logging +from typing import Any + +from agentkit.llm.providers.openai import OpenAICompatibleProvider +from agentkit.llm.protocol import LLMRequest, LLMResponse + +logger = logging.getLogger(__name__) + +# 混元模型映射 +YUANBAO_MODEL_MAP = { + "hunyuan-turbos-latest": "hunyuan-turbos-latest", + "hunyuan-2.0": "hunyuan-2.0", + "hunyuan-t1": "hunyuan-t1", + "hunyuan-vision-1.5": "hunyuan-vision-1.5", +} + +# 腾讯混元 API base URL +YUANBAO_DEFAULT_BASE_URL = "https://api.hunyuan.cloud.tencent.com/v1" + + +class YuanbaoProvider(OpenAICompatibleProvider): + """腾讯混元/元宝 Provider + + 通过腾讯云 OpenAI 兼容接口调用混元模型。 + + 使用方式: + provider = YuanbaoProvider( + api_key="your_hunyuan_api_key", + default_model="hunyuan-turbos-latest", + ) + + 特殊参数: + - enable_enhancement: 增强模式(通过 LLMRequest._extra 传递) + """ + + def __init__( + self, + api_key: str, + base_url: str = YUANBAO_DEFAULT_BASE_URL, + default_model: str = "hunyuan-turbos-latest", + enable_enhancement: bool = False, + **kwargs: Any, + ): + self._enable_enhancement = enable_enhancement + super().__init__( + api_key=api_key, + base_url=base_url, + default_model=default_model, + **kwargs, + ) + + async def chat(self, request: LLMRequest) -> LLMResponse: + """发送 chat 请求,处理混元模型映射和增强模式""" + request.model = YUANBAO_MODEL_MAP.get(request.model, request.model) + + # Add enhancement parameter if enabled + if self._enable_enhancement: + if not hasattr(request, "_extra") or request._extra is None: + request._extra = {} + request._extra["enable_enhancement"] = True + + return await super().chat(request) diff --git a/tests/unit/test_chinese_providers.py b/tests/unit/test_chinese_providers.py new file mode 100644 index 0000000..c5cfbe3 --- /dev/null +++ b/tests/unit/test_chinese_providers.py @@ -0,0 +1,120 @@ +"""Tests for Chinese LLM Providers (Wenxin, Doubao, Yuanbao)""" + +import pytest + +from agentkit.llm.providers.wenxin import WenxinProvider, WENXIN_MODEL_MAP +from agentkit.llm.providers.doubao import DoubaoProvider, DOUBAO_MODEL_MAP +from agentkit.llm.providers.yuanbao import YuanbaoProvider, YUANBAO_MODEL_MAP +from agentkit.llm.protocol import LLMRequest + + +class TestWenxinProvider: + """WenxinProvider unit tests""" + + def test_init_with_api_key(self): + provider = WenxinProvider(api_key="test_key") + assert provider._api_key == "test_key" + assert provider._default_model == "ernie-4.5-turbo-128k" + + def test_init_with_ak_sk(self): + provider = WenxinProvider( + api_key="", + access_key="test_ak", + secret_key="test_sk", + ) + assert provider._access_key == "test_ak" + assert provider._secret_key == "test_sk" + + def test_model_mapping(self): + assert "ernie-4.5-turbo-128k" in WENXIN_MODEL_MAP + assert "ernie-5.0" in WENXIN_MODEL_MAP + assert "ernie-x1.1" in WENXIN_MODEL_MAP + + def test_default_base_url(self): + from agentkit.llm.providers.wenxin import WENXIN_DEFAULT_BASE_URL + assert "qianfan.baidubce.com" in WENXIN_DEFAULT_BASE_URL + + def test_custom_base_url(self): + provider = WenxinProvider(api_key="test", base_url="https://custom.api.com/v2") + assert "custom.api.com" in provider._base_url + + +class TestDoubaoProvider: + """DoubaoProvider unit tests""" + + def test_init(self): + provider = DoubaoProvider(api_key="test_key") + assert provider._api_key == "test_key" + assert provider._default_model == "doubao-pro-32k" + + def test_model_mapping(self): + assert "doubao-pro-32k" in DOUBAO_MODEL_MAP + assert "doubao-lite-32k" in DOUBAO_MODEL_MAP + assert "doubao-vision" in DOUBAO_MODEL_MAP + + def test_default_base_url(self): + from agentkit.llm.providers.doubao import DOUBAO_DEFAULT_BASE_URL + assert "ark.cn-beijing.volces.com" in DOUBAO_DEFAULT_BASE_URL + + def test_custom_model(self): + provider = DoubaoProvider( + api_key="test", + default_model="doubao-lite-32k", + ) + assert provider._default_model == "doubao-lite-32k" + + +class TestYuanbaoProvider: + """YuanbaoProvider unit tests""" + + def test_init(self): + provider = YuanbaoProvider(api_key="test_key") + assert provider._api_key == "test_key" + assert provider._default_model == "hunyuan-turbos-latest" + + def test_init_with_enhancement(self): + provider = YuanbaoProvider(api_key="test", enable_enhancement=True) + assert provider._enable_enhancement is True + + def test_model_mapping(self): + assert "hunyuan-turbos-latest" in YUANBAO_MODEL_MAP + assert "hunyuan-2.0" in YUANBAO_MODEL_MAP + assert "hunyuan-t1" in YUANBAO_MODEL_MAP + + def test_default_base_url(self): + from agentkit.llm.providers.yuanbao import YUANBAO_DEFAULT_BASE_URL + assert "hunyuan.cloud.tencent.com" in YUANBAO_DEFAULT_BASE_URL + + def test_enhancement_disabled_by_default(self): + provider = YuanbaoProvider(api_key="test") + assert provider._enable_enhancement is False + + +class TestProviderImports: + """Test that all providers are importable from the package""" + + def test_import_all_providers(self): + from agentkit.llm.providers import ( + AnthropicProvider, + DoubaoProvider, + GeminiProvider, + OpenAICompatibleProvider, + WenxinProvider, + YuanbaoProvider, + ) + assert AnthropicProvider is not None + assert DoubaoProvider is not None + assert GeminiProvider is not None + assert OpenAICompatibleProvider is not None + assert WenxinProvider is not None + assert YuanbaoProvider is not None + + def test_inheritance(self): + """All providers should inherit from OpenAICompatibleProvider or LLMProvider""" + from agentkit.llm.providers.openai import OpenAICompatibleProvider + from agentkit.llm.protocol import LLMProvider + + assert issubclass(WenxinProvider, OpenAICompatibleProvider) + assert issubclass(DoubaoProvider, OpenAICompatibleProvider) + assert issubclass(YuanbaoProvider, OpenAICompatibleProvider) + assert issubclass(OpenAICompatibleProvider, LLMProvider)