129 lines
4.7 KiB
Python
129 lines
4.7 KiB
Python
"""LLM Config - 配置加载"""
|
|
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
import yaml
|
|
|
|
from agentkit.llm.retry import CircuitBreakerConfig, RetryConfig
|
|
|
|
|
|
@dataclass
|
|
class CacheConfig:
|
|
"""LLM Cache 配置"""
|
|
|
|
enabled: bool = False
|
|
backend: str = "auto" # "auto" | "redis" | "memory"
|
|
redis_url: str = "redis://localhost:6379"
|
|
exact_ttl: int = 3600
|
|
semantic_ttl: int = 86400
|
|
similarity_threshold: float = 0.92
|
|
max_entries: int = 10000
|
|
# Embedding config for semantic cache (Chinese-first: bge-m3 via Xinference)
|
|
embedding_provider: str = "openai" # "openai" | "xinference" | "local"
|
|
embedding_model: str = "bge-m3"
|
|
embedding_base_url: str | None = None
|
|
embedding_api_key: str | None = None
|
|
|
|
@classmethod
|
|
def from_dict(cls, data: dict) -> "CacheConfig":
|
|
if not data:
|
|
return cls()
|
|
emb = data.get("embedding", {})
|
|
return cls(
|
|
enabled=data.get("enabled", False),
|
|
backend=data.get("backend", "auto"),
|
|
redis_url=data.get("redis_url", "redis://localhost:6379"),
|
|
exact_ttl=data.get("exact_ttl", 3600),
|
|
semantic_ttl=data.get("semantic_ttl", 86400),
|
|
similarity_threshold=data.get("similarity_threshold", 0.92),
|
|
max_entries=data.get("max_entries", 10000),
|
|
embedding_provider=emb.get("provider", "openai"),
|
|
embedding_model=emb.get("model", "bge-m3"),
|
|
embedding_base_url=emb.get("base_url"),
|
|
embedding_api_key=emb.get("api_key"),
|
|
)
|
|
|
|
|
|
@dataclass
class ProviderConfig:
    """Settings for a single LLM provider endpoint.

    Only ``api_key`` and ``base_url`` are required (and positional); every
    other field carries a default. ``retry`` and ``circuit_breaker`` stay
    ``None`` unless explicitly supplied.
    """

    # --- required: credentials and endpoint -------------------------------
    api_key: str
    base_url: str

    # --- model table --------------------------------------------------------
    # Per-model settings keyed by model name. The inner mapping's schema is
    # defined by the code that consumes this config (not visible here).
    models: dict[str, dict[str, Any]] = field(default_factory=dict)

    # Provider protocol: "openai" | "anthropic" | "gemini".
    type: str = "openai"

    # --- Anthropic client settings ------------------------------------------
    max_tokens: int = 4096  # default max_tokens for requests
    timeout: float = 120.0  # request timeout

    # --- httpx connection pool ----------------------------------------------
    max_connections: int = 100  # cap on total pooled connections
    max_keepalive_connections: int = 20  # cap on idle keep-alive connections
    keepalive_expiry: float = 30.0  # seconds before an idle keep-alive connection expires

    # --- resilience policies (None = not configured) ------------------------
    retry: RetryConfig | None = None
    circuit_breaker: CircuitBreakerConfig | None = None
|
|
|
|
|
|
@dataclass
class LLMConfig:
    """Top-level LLM configuration: providers, model aliases, fallbacks, cache."""

    # Provider name -> that provider's connection/client settings.
    providers: dict[str, ProviderConfig] = field(default_factory=dict)
    # Alias -> model name (alias resolution happens in the consumer).
    model_aliases: dict[str, str] = field(default_factory=dict)
    # Model name -> list of fallback model names (presumably tried in order;
    # NOTE(review): confirm against the caller).
    fallbacks: dict[str, list[str]] = field(default_factory=dict)
    # None when the config has no ``cache`` section or it is empty.
    cache: CacheConfig | None = None

    @classmethod
    def from_yaml(cls, path: str) -> "LLMConfig":
        """Load configuration from a YAML file.

        The file is decoded as UTF-8 and parsed with ``yaml.safe_load``, which
        does not construct arbitrary Python objects. An empty file yields a
        default config.

        Args:
            path: Path to the YAML file.

        Returns:
            The parsed ``LLMConfig``.

        Raises:
            OSError: If the file cannot be opened.
            yaml.YAMLError: If the file is not valid YAML.
        """
        with open(path, encoding="utf-8") as f:
            data = yaml.safe_load(f)
        return cls.from_dict(data or {})

    @classmethod
    def from_dict(cls, data: dict | None) -> "LLMConfig":
        """Build an ``LLMConfig`` from an already-parsed mapping.

        Sections that are missing, or present but null (e.g. ``providers:``
        with nothing under it in YAML), are treated as empty instead of
        crashing or leaking ``None`` into dict-typed fields.

        Args:
            data: Mapping with optional keys ``providers``, ``model_aliases``,
                ``fallbacks`` and ``cache``. ``None`` is treated as ``{}``.

        Returns:
            The populated ``LLMConfig``.
        """
        data = data or {}
        providers = {
            # A provider listed with no body parses as None; treat it as {}.
            name: cls._provider_from_dict(pconf or {})
            for name, pconf in (data.get("providers") or {}).items()
        }
        cache_data = data.get("cache")
        cache = CacheConfig.from_dict(cache_data) if cache_data else None
        return cls(
            providers=providers,
            model_aliases=data.get("model_aliases") or {},
            fallbacks=data.get("fallbacks") or {},
            cache=cache,
        )

    @classmethod
    def _provider_from_dict(cls, pconf: dict) -> ProviderConfig:
        """Build one ``ProviderConfig`` from its raw mapping."""
        return ProviderConfig(
            api_key=pconf.get("api_key", ""),
            base_url=pconf.get("base_url", ""),
            # ``models:`` left empty in YAML parses as None; keep it a dict.
            models=pconf.get("models") or {},
            type=pconf.get("type", "openai"),
            max_tokens=pconf.get("max_tokens", 4096),
            timeout=pconf.get("timeout", 120.0),
            max_connections=pconf.get("max_connections", 100),
            max_keepalive_connections=pconf.get("max_keepalive_connections", 20),
            keepalive_expiry=pconf.get("keepalive_expiry", 30.0),
            retry=cls._retry_from_dict(pconf.get("retry")),
            circuit_breaker=cls._circuit_breaker_from_dict(
                pconf.get("circuit_breaker")
            ),
        )

    @staticmethod
    def _retry_from_dict(retry_data: dict | None) -> RetryConfig | None:
        """Build a ``RetryConfig`` from a ``retry`` section; None if absent or empty."""
        if not retry_data:
            return None
        # Fallbacks are spelled out here rather than taken from RetryConfig's
        # own defaults. NOTE(review): keep the two in sync.
        return RetryConfig(
            max_retries=retry_data.get("max_retries", 3),
            base_delay=retry_data.get("base_delay", 1.0),
            max_delay=retry_data.get("max_delay", 30.0),
            exponential_base=retry_data.get("exponential_base", 2.0),
        )

    @staticmethod
    def _circuit_breaker_from_dict(
        cb_data: dict | None,
    ) -> CircuitBreakerConfig | None:
        """Build a ``CircuitBreakerConfig`` from a section; None if absent or empty."""
        if not cb_data:
            return None
        return CircuitBreakerConfig(
            failure_threshold=cb_data.get("failure_threshold", 5),
            recovery_timeout=cb_data.get("recovery_timeout", 60.0),
            half_open_max=cb_data.get("half_open_max", 1),
        )
|