fischer-agentkit/src/agentkit/llm/config.py

129 lines
4.7 KiB
Python

"""LLM Config - 配置加载"""
from dataclasses import dataclass, field
from typing import Any
import yaml
from agentkit.llm.retry import CircuitBreakerConfig, RetryConfig
@dataclass
class CacheConfig:
"""LLM Cache 配置"""
enabled: bool = False
backend: str = "auto" # "auto" | "redis" | "memory"
redis_url: str = "redis://localhost:6379"
exact_ttl: int = 3600
semantic_ttl: int = 86400
similarity_threshold: float = 0.92
max_entries: int = 10000
# Embedding config for semantic cache (Chinese-first: bge-m3 via Xinference)
embedding_provider: str = "openai" # "openai" | "xinference" | "local"
embedding_model: str = "bge-m3"
embedding_base_url: str | None = None
embedding_api_key: str | None = None
@classmethod
def from_dict(cls, data: dict) -> "CacheConfig":
if not data:
return cls()
emb = data.get("embedding", {})
return cls(
enabled=data.get("enabled", False),
backend=data.get("backend", "auto"),
redis_url=data.get("redis_url", "redis://localhost:6379"),
exact_ttl=data.get("exact_ttl", 3600),
semantic_ttl=data.get("semantic_ttl", 86400),
similarity_threshold=data.get("similarity_threshold", 0.92),
max_entries=data.get("max_entries", 10000),
embedding_provider=emb.get("provider", "openai"),
embedding_model=emb.get("model", "bge-m3"),
embedding_base_url=emb.get("base_url"),
embedding_api_key=emb.get("api_key"),
)
@dataclass
class ProviderConfig:
    """Settings for a single LLM provider.

    ``api_key`` and ``base_url`` are required. Every other field has a default.
    The retry and circuit-breaker policies stay ``None`` unless a caller
    supplies them.
    """

    # Required connection settings (no defaults).
    api_key: str
    base_url: str
    # Per-model settings; presumably keyed by model name (inner dict schema not fixed here).
    models: dict[str, dict[str, Any]] = field(default_factory=dict)
    # Provider type: "openai" | "anthropic" | "gemini".
    type: str = field(default="openai")
    # Anthropic: default max_tokens.
    max_tokens: int = field(default=4096)
    # Anthropic: request timeout.
    timeout: float = field(default=120.0)
    # httpx connection pool: max connections, max keep-alive connections,
    # and keep-alive connection expiry (seconds).
    max_connections: int = field(default=100)
    max_keepalive_connections: int = field(default=20)
    keepalive_expiry: float = field(default=30.0)
    # Optional resilience policies; None when not configured.
    retry: RetryConfig | None = field(default=None)
    circuit_breaker: CircuitBreakerConfig | None = field(default=None)
@dataclass
class LLMConfig:
    """Top-level LLM configuration.

    Holds per-provider settings, model aliases, fallback lists and optional
    cache settings. It is usually built with :meth:`from_yaml` or
    :meth:`from_dict`.
    """

    # Provider name -> ProviderConfig, keyed by the names under "providers".
    providers: dict[str, ProviderConfig] = field(default_factory=dict)
    # Taken as-is from "model_aliases" (presumably alias -> model name; resolved elsewhere).
    model_aliases: dict[str, str] = field(default_factory=dict)
    # Taken as-is from "fallbacks" (presumably model -> ordered fallback models).
    fallbacks: dict[str, list[str]] = field(default_factory=dict)
    # None when the "cache" section is missing or empty.
    cache: CacheConfig | None = None

    @classmethod
    def from_yaml(cls, path: str) -> "LLMConfig":
        """Load configuration from a UTF-8 YAML file.

        The file is parsed with ``yaml.safe_load``. An empty file yields a
        default ``LLMConfig``.
        """
        with open(path, encoding="utf-8") as f:
            data = yaml.safe_load(f)
        return cls.from_dict(data or {})

    @classmethod
    def from_dict(cls, data: dict | None) -> "LLMConfig":
        """Build an LLMConfig from an already-parsed mapping.

        A section that is missing *or* explicitly null is treated as empty.
        For example, ``providers:`` with no body in YAML, or a provider entry
        with no body, no longer raises ``AttributeError``. Null
        ``model_aliases`` / ``fallbacks`` sections also no longer store
        ``None`` into dict-typed fields.

        Args:
            data: The top-level config mapping. ``None`` is treated as ``{}``.

        Returns:
            The populated LLMConfig.
        """
        data = data or {}
        providers = {
            name: cls._parse_provider(pconf or {})
            for name, pconf in (data.get("providers") or {}).items()
        }
        cache_data = data.get("cache")
        return cls(
            providers=providers,
            model_aliases=data.get("model_aliases") or {},
            fallbacks=data.get("fallbacks") or {},
            # An absent or empty section leaves cache as None rather than a default CacheConfig.
            cache=CacheConfig.from_dict(cache_data) if cache_data else None,
        )

    @classmethod
    def _parse_provider(cls, pconf: dict) -> ProviderConfig:
        """Build a ProviderConfig from one entry of the "providers" section."""
        return ProviderConfig(
            api_key=pconf.get("api_key", ""),
            base_url=pconf.get("base_url", ""),
            # "models:" with no body parses to None; keep the field a dict.
            models=pconf.get("models") or {},
            type=pconf.get("type", "openai"),
            max_tokens=pconf.get("max_tokens", 4096),
            timeout=pconf.get("timeout", 120.0),
            max_connections=pconf.get("max_connections", 100),
            max_keepalive_connections=pconf.get("max_keepalive_connections", 20),
            keepalive_expiry=pconf.get("keepalive_expiry", 30.0),
            retry=cls._parse_retry(pconf.get("retry")),
            circuit_breaker=cls._parse_circuit_breaker(pconf.get("circuit_breaker")),
        )

    @staticmethod
    def _parse_retry(retry_data: dict | None) -> RetryConfig | None:
        """Build a RetryConfig from a provider's "retry" section; None if absent or empty."""
        if not retry_data:
            return None
        return RetryConfig(
            max_retries=retry_data.get("max_retries", 3),
            base_delay=retry_data.get("base_delay", 1.0),
            max_delay=retry_data.get("max_delay", 30.0),
            exponential_base=retry_data.get("exponential_base", 2.0),
        )

    @staticmethod
    def _parse_circuit_breaker(cb_data: dict | None) -> CircuitBreakerConfig | None:
        """Build a CircuitBreakerConfig from a provider's "circuit_breaker" section; None if absent or empty."""
        if not cb_data:
            return None
        return CircuitBreakerConfig(
            failure_threshold=cb_data.get("failure_threshold", 5),
            recovery_timeout=cb_data.get("recovery_timeout", 60.0),
            half_open_max=cb_data.get("half_open_max", 1),
        )