66 lines
2.0 KiB
Python
66 lines
2.0 KiB
Python
import os
|
||
|
||
from .base import LLMError, LLMProvider
|
||
from .openai_provider import OpenAIProvider
|
||
from .deepseek_provider import DeepSeekProvider
|
||
|
||
|
||
class LLMFactory:
    """LLM provider factory: creates provider instances from a shared name -> class registry."""

    # Registry mapping a normalized provider name (stripped + lower-cased) to its class.
    # It lives on the class, so it is shared by LLMFactory and any subclass of it.
    _providers: dict[str, type[LLMProvider]] = {}

    @staticmethod
    def _normalize(name: str | None) -> str:
        """Return *name* stripped of surrounding whitespace and lower-cased ("" for None)."""
        return (name or "").strip().lower()

    @classmethod
    def register(cls, name: str, provider_cls: type[LLMProvider]) -> None:
        """
        Register a provider class under a name.

        Registering a name that already exists replaces the previous class.

        Args:
            name: Provider name (e.g. "openai", "deepseek"); matched
                case-insensitively and ignoring surrounding whitespace.
            provider_cls: LLMProvider subclass to instantiate for this name.

        Raises:
            ValueError: if ``name`` is empty or whitespace-only.
        """
        key = cls._normalize(name)
        if not key:
            # create() treats a blank name as "use the default", so a provider
            # registered under "" could never be selected; reject it up front.
            raise ValueError(f"Provider名称不能为空: {name!r}")
        cls._providers[key] = provider_cls

    @classmethod
    def create(cls, provider: str | None = None, model: str | None = None) -> LLMProvider:
        """
        Create an LLM provider instance.

        Args:
            provider: Provider name, e.g. "openai" | "deepseek". When None or
                blank, the name is read from the DEFAULT_LLM_PROVIDER
                environment variable; if that is unset or blank too,
                "openai" is used.
            model: Overrides the provider's default model name. It is only
                passed to the provider class when truthy.

        Returns:
            A new instance of the registered provider class.

        Raises:
            LLMError: if the resolved provider name is not registered.
        """
        # Blank values (e.g. `DEFAULT_LLM_PROVIDER=` in a .env file) fall through
        # to the next source instead of becoming an unknown "" provider.
        provider_name = (
            cls._normalize(provider)
            or cls._normalize(os.getenv("DEFAULT_LLM_PROVIDER"))
            or "openai"
        )

        if provider_name not in cls._providers:
            available = ", ".join(cls._providers) or "(无)"
            raise LLMError(
                f"未注册的Provider: {provider_name},可用: {available}",
                provider="factory",
            )

        provider_cls = cls._providers[provider_name]
        # Forward `model` only when given, so the provider's own default applies otherwise.
        return provider_cls(model=model) if model else provider_cls()

    @classmethod
    def get_default(cls) -> LLMProvider:
        """Return a provider for the default configuration (equivalent to ``create()``)."""
        return cls.create()

    @classmethod
    def list_providers(cls) -> list[str]:
        """Return the names of all registered providers, in registry insertion order."""
        return list(cls._providers.keys())
|
||
|
||
|
||
# ---- Auto-register the built-in providers (order matters for list_providers) ----
for _builtin_name, _builtin_cls in (
    ("openai", OpenAIProvider),
    ("deepseek", DeepSeekProvider),
):
    LLMFactory.register(_builtin_name, _builtin_cls)
# Drop the loop variables so the module namespace matches plain register() calls.
del _builtin_name, _builtin_cls
|