geo/backend/app/services/llm/factory.py

66 lines
2.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from .base import LLMError, LLMProvider
from .openai_provider import OpenAIProvider
from .deepseek_provider import DeepSeekProvider
class LLMFactory:
    """Factory for LLM providers: one place to register and instantiate them.

    Provider classes live in a class-level registry keyed by a normalized
    (stripped, lower-cased) name, so registration and lookup are
    case-insensitive and tolerant of surrounding whitespace.
    """

    # Environment variable consulted when create() receives no provider name.
    _DEFAULT_PROVIDER_ENV = "DEFAULT_LLM_PROVIDER"
    # Provider used when neither the argument nor the environment names one.
    _FALLBACK_PROVIDER = "openai"

    # Registry shared by all callers: normalized provider name -> provider class.
    _providers: dict[str, type[LLMProvider]] = {}

    @staticmethod
    def _normalize_name(name: str) -> str:
        """Return the registry key for *name* (stripped and lower-cased)."""
        return name.strip().lower()

    @classmethod
    def register(cls, name: str, provider_cls: type[LLMProvider]) -> None:
        """Register a provider class under a name.

        Registering an already-used name replaces the previous class.

        Args:
            name: Provider name, e.g. "openai" or "deepseek". Stored
                stripped and lower-cased.
            provider_cls: LLMProvider subclass to instantiate for that name.
        """
        cls._providers[cls._normalize_name(name)] = provider_cls

    @classmethod
    def create(cls, provider: str | None = None, model: str | None = None) -> LLMProvider:
        """Create an LLM provider instance.

        Args:
            provider: Provider name such as "openai" or "deepseek". When
                omitted or blank, the DEFAULT_LLM_PROVIDER environment
                variable is used; when that is unset or blank as well,
                "openai" is used.
            model: Overrides the provider's default model name. A falsy value
                (None or "") leaves the provider's own default in place.

        Returns:
            An instance of the registered provider class.

        Raises:
            LLMError: If the resolved provider name is not registered.
        """
        provider_name = cls._normalize_name(provider or "")
        if not provider_name:
            # An env var that is set but empty (typical for .env placeholders)
            # must behave like an unset one instead of selecting provider "".
            provider_name = cls._normalize_name(
                os.getenv(cls._DEFAULT_PROVIDER_ENV) or ""
            )
        if not provider_name:
            provider_name = cls._FALLBACK_PROVIDER

        provider_cls = cls._providers.get(provider_name)
        if provider_cls is None:
            available = ", ".join(cls._providers) or "(无)"
            raise LLMError(
                f"未注册的Provider: {provider_name},可用: {available}",
                provider="factory",
            )

        # Only pass ``model`` when one was given so each provider keeps its
        # own constructor default otherwise.
        return provider_cls(model=model) if model else provider_cls()

    @classmethod
    def get_default(cls) -> LLMProvider:
        """Return the default provider instance; equivalent to ``create()``."""
        return cls.create()

    @classmethod
    def list_providers(cls) -> list[str]:
        """Return the names of all registered providers, in registration order."""
        return list(cls._providers)
# ---- Auto-register the built-in providers ----
# Runs at import time; the tuple order is the registration order.
for _builtin_name, _builtin_cls in (
    ("openai", OpenAIProvider),
    ("deepseek", DeepSeekProvider),
):
    LLMFactory.register(_builtin_name, _builtin_cls)