"""Server configuration loader - loads agentkit.yaml and .env"""

import asyncio
import logging
import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable

import yaml

from agentkit.llm.config import LLMConfig, ProviderConfig
from agentkit.skills.base import SkillConfig

logger = logging.getLogger(__name__)

# Default config file name
DEFAULT_CONFIG_FILE = "agentkit.yaml"

# Matches ``${VAR}`` and ``${VAR:-default}`` placeholders. Compiled once at
# import time instead of on every string that gets resolved.
_ENV_VAR_PATTERN = re.compile(r"\$\{([^}]+)\}")


@dataclass
class MCPServerConfig:
    """Configuration for a single MCP Server connection"""

    transport: str  # "stdio" | "streamable_http" | "sse"
    # stdio-specific
    command: str | None = None
    args: list[str] | None = None
    env: dict[str, str] | None = None
    # http/sse-specific
    url: str | None = None
    headers: dict[str, str] | None = None
    # common
    timeout: float = 30.0

    def validate(self) -> None:
        """Validate configuration.

        Raises:
            ValueError: If the transport is unknown, or the field required by
                the chosen transport (``command`` for stdio, ``url`` for
                streamable_http/sse) is missing or empty.
        """
        if self.transport not in ("stdio", "streamable_http", "sse"):
            raise ValueError(f"Invalid transport: {self.transport}")
        if self.transport == "stdio" and not self.command:
            raise ValueError("stdio transport requires 'command'")
        if self.transport in ("streamable_http", "sse") and not self.url:
            raise ValueError(f"{self.transport} transport requires 'url'")

    @classmethod
    def from_dict(cls, data: dict) -> "MCPServerConfig":
        """Create from dict (parsed from YAML).

        Missing keys fall back to the dataclass defaults; ``transport``
        defaults to ``"stdio"``. No validation is performed here.
        """
        return cls(
            transport=data.get("transport", "stdio"),
            command=data.get("command"),
            args=data.get("args"),
            env=data.get("env"),
            url=data.get("url"),
            headers=data.get("headers"),
            timeout=data.get("timeout", 30.0),
        )


def _resolve_env_vars(value: Any) -> Any:
    """Resolve ``${VAR}`` / ``${VAR:-default}`` patterns from the environment.

    Non-string values are returned unchanged.

    * ``${VAR:-default}`` expands to the value of ``VAR``, or to ``default``
      when ``VAR`` is unset *or empty* (shell ``:-`` semantics, and consistent
      with ``load_dotenv`` treating empty values as "not set").
    * ``${VAR}`` expands to the value of ``VAR``; when ``VAR`` is unset the
      placeholder is left in place so callers can detect it as unresolved.
    """
    if not isinstance(value, str):
        return value

    def replacer(match: "re.Match[str]") -> str:
        expr = match.group(1)
        if ":-" in expr:
            var_name, default = expr.split(":-", 1)
            # `or` (not a .get default) so that an empty value also falls back.
            return os.environ.get(var_name) or default
        return os.environ.get(expr, match.group(0))

    return _ENV_VAR_PATTERN.sub(replacer, value)


def _deep_resolve(data: Any) -> Any:
    """Recursively resolve env vars in nested dicts/lists.

    Dict keys are left untouched; only values are resolved. Anything that is
    not a dict, list, or str is returned as-is.
    """
    if isinstance(data, dict):
        return {k: _deep_resolve(v) for k, v in data.items()}
    if isinstance(data, list):
        return [_deep_resolve(item) for item in data]
    if isinstance(data, str):
        return _resolve_env_vars(data)
    return data


class ServerConfig:
    """Server configuration loaded from agentkit.yaml"""

    # Public attributes copied from a freshly parsed config on hot-reload.
    # Keep in sync with the parameters of __init__ (``on_change`` and the
    # private watcher state are deliberately excluded).
    _RELOADABLE_FIELDS: tuple[str, ...] = (
        "host",
        "port",
        "workers",
        "api_key",
        "rate_limit",
        "llm_config",
        "skill_paths",
        "auto_discover_skills",
        "log_level",
        "log_format",
        "task_store",
        "cors_origins",
        "memory",
        "mcp_servers",
        "telemetry",
        "compression",
        "session",
        "bus",
        "marketplace",
        "alignment",
        "router",
        "usage_store",
        "cascade_store",
        "evolution",
        "expert_paths",
        "board",
        "prompt_cache",
        "streaming",
        "verification",
        "rollback",
        "fallback_chain",
        "plan_exec",
    )

    def __init__(
        self,
        host: str = "0.0.0.0",
        port: int = 18001,
        workers: int = 1,
        api_key: str | None = None,
        rate_limit: int = 60,
        llm_config: LLMConfig | None = None,
        skill_paths: list[str] | None = None,
        auto_discover_skills: bool = True,
        log_level: str = "INFO",
        log_format: str = "text",
        task_store: dict[str, Any] | None = None,
        cors_origins: list[str] | None = None,
        memory: dict[str, Any] | None = None,
        mcp_servers: dict[str, MCPServerConfig] | None = None,
        telemetry: dict[str, Any] | None = None,
        compression: dict[str, Any] | None = None,
        session: dict[str, Any] | None = None,
        bus: dict[str, Any] | None = None,
        marketplace: dict[str, Any] | None = None,
        alignment: dict[str, Any] | None = None,
        router: dict[str, Any] | None = None,
        usage_store: dict[str, Any] | None = None,
        cascade_store: dict[str, Any] | None = None,
        evolution: dict[str, Any] | None = None,
        expert_paths: list[str] | None = None,
        board: dict[str, Any] | None = None,
        prompt_cache: dict[str, Any] | None = None,
        streaming: dict[str, Any] | None = None,
        verification: dict[str, Any] | None = None,
        rollback: dict[str, Any] | None = None,
        fallback_chain: dict[str, Any] | None = None,
        # G6/U2: PLAN_EXEC phase policy config (opt-in — None = disabled).
        # Parsed via PhasePolicy.policy_from_config() at chat.py wiring time.
        plan_exec: dict[str, Any] | None = None,
        on_change: Callable[["ServerConfig"], None] | None = None,
    ):
        """Store the given settings, replacing ``None``/empty sections with
        fresh empty containers (``cors_origins`` falls back to ``["*"]``)."""
        self.host = host
        self.port = port
        self.workers = workers
        self.api_key = api_key
        self.rate_limit = rate_limit
        self.llm_config = llm_config or LLMConfig()
        self.skill_paths = skill_paths or []
        self.auto_discover_skills = auto_discover_skills
        self.log_level = log_level
        self.log_format = log_format
        self.task_store = task_store or {}
        self.cors_origins = cors_origins or ["*"]
        self.memory = memory or {}
        self.mcp_servers = mcp_servers or {}
        self.telemetry = telemetry or {}
        self.compression = compression or {}
        self.session = session or {}
        self.bus = bus or {}
        self.marketplace = marketplace or {}
        self.alignment = alignment or {}
        self.router = router or {}
        self.usage_store = usage_store or {}
        self.cascade_store = cascade_store or {}
        self.evolution = evolution or {}
        self.expert_paths = expert_paths or []
        self.board = board or {}
        self.prompt_cache = prompt_cache or {}
        # U3/G8: streaming.flush_interval_ms throttles token chunks
        # (default 0 = yield every chunk).
        self.streaming = streaming or {}
        # U4/G1: verification.max_reinjections controls how many times a
        # failed verify is re-injected (default 1). Has no effect when
        # verification_enabled=False.
        self.verification = verification or {}
        # G9/U4: rollback.default_timeout controls the RollbackExecutor
        # subprocess timeout. Has no effect when PlanPhase.rollback_command
        # is not set (KTD6 opt-in).
        self.rollback = rollback or {}
        # G7/U3: fallback_chain.{recovery,emergency}.{enabled,max_retries}
        # controls three-tier chain at chat.py REST send_message (KTD5).
        self.fallback_chain = fallback_chain or {}
        # G6/U2: plan_exec phase policy config (opt-in — empty dict = disabled).
        # Resolved to PhasePolicy via agentkit.core.phase.policy_from_config()
        # at chat.py WebSocket wiring time (U4).
        self.plan_exec = plan_exec or {}
        self.on_change = on_change

        # Config watching state
        self._config_path: str | None = None
        self._watcher_task: asyncio.Task | None = None
        self._last_mtime: float = 0.0

    def has_llm_provider(self) -> bool:
        """Return True if at least one LLM provider has a usable API key.

        A key counts as usable when it is non-empty and does not still start
        with ``${`` (i.e. it is not an unresolved env-var placeholder).
        """
        return any(
            provider.api_key and not provider.api_key.startswith("${")
            for provider in self.llm_config.providers.values()
        )

    @classmethod
    def from_yaml(cls, path: str) -> "ServerConfig":
        """Load configuration from a YAML file.

        Environment-variable placeholders are resolved before parsing, and the
        file path and mtime are recorded for later change watching.
        """
        with open(path, encoding="utf-8") as f:
            data = yaml.safe_load(f) or {}

        # Resolve environment variables
        data = _deep_resolve(data)

        config = cls.from_dict(data)
        config._config_path = path
        config._last_mtime = os.path.getmtime(path)
        return config

    @classmethod
    def from_dict(cls, data: dict) -> "ServerConfig":
        """Create ServerConfig from a dictionary.

        Sections that are missing *or explicitly null* in the YAML (for
        example a bare ``server:`` key) are treated as empty.
        """
        # `or {}` rather than a .get() default: a key that is present with a
        # null value would otherwise yield None and crash on the .get() below.
        server = data.get("server") or {}
        llm_data = data.get("llm") or {}
        skills_data = data.get("skills") or {}
        logging_data = data.get("logging") or {}
        mcp_data = data.get("mcp") or {}
        # Expert templates config (paths to YAML files defining ExpertTemplates)
        experts_data = data.get("experts") or {}

        return cls(
            host=server.get("host", "0.0.0.0"),
            port=server.get("port", 18001),
            workers=server.get("workers", 1),
            api_key=server.get("api_key"),
            rate_limit=server.get("rate_limit", 60),
            llm_config=cls._build_llm_config(llm_data),
            skill_paths=skills_data.get("paths", []),
            auto_discover_skills=skills_data.get("auto_discover", True),
            log_level=logging_data.get("level", "INFO"),
            log_format=logging_data.get("format", "text"),
            task_store=data.get("task_store", {}),
            cors_origins=server.get("cors_origins"),
            memory=data.get("memory", {}),
            mcp_servers=cls._build_mcp_configs(mcp_data),
            telemetry=data.get("telemetry", {}),
            compression=data.get("compression", {}),
            session=data.get("session", {}),
            # NOTE(review): unlike the other sections, `bus` is read from
            # under `server`, not from the top level — confirm this is intended.
            bus=server.get("bus"),
            marketplace=data.get("marketplace", {}),
            alignment=data.get("alignment", {}),
            router=data.get("router", {}),
            usage_store=data.get("usage_store", {}),
            cascade_store=data.get("cascade_store", {}),
            evolution=data.get("evolution", {}),
            expert_paths=experts_data.get("paths", []),
            # Board meeting config (max_rounds, default_template, etc.)
            board=data.get("board", {}),
            # U2/U3/U4: prompt_cache / streaming / verification (read from YAML)
            prompt_cache=data.get("prompt_cache", {}),
            streaming=data.get("streaming", {}),
            verification=data.get("verification", {}),
            # G9/U4: rollback config (read from YAML, opt-in)
            rollback=data.get("rollback", {}),
            # G7/U3: fallback_chain config (read from YAML)
            fallback_chain=data.get("fallback_chain", {}),
            # G6/U2: plan_exec phase policy config (read from YAML, opt-in)
            plan_exec=data.get("plan_exec", {}),
        )

    @staticmethod
    def _build_llm_config(data: dict) -> LLMConfig:
        """Build LLMConfig from the llm section of agentkit.yaml.

        Model aliases are collected from per-model ``alias`` fields and then
        overlaid with the top-level ``model_aliases`` mapping, which wins on
        conflict.
        """
        from agentkit.llm.config import CacheConfig

        providers = {}
        model_aliases = {}

        for name, pconf in (data.get("providers") or {}).items():
            # A provider declared with no body parses as None.
            pconf = pconf or {}
            models = pconf.get("models", {})
            if models is None:
                models = {}

            # Build model aliases from alias fields within model configs.
            # Only a mapping of model name -> config can carry aliases.
            if isinstance(models, dict):
                for model_name, model_conf in models.items():
                    alias = model_conf.get("alias") if isinstance(model_conf, dict) else None
                    if alias:
                        model_aliases[alias] = f"{name}/{model_name}"

            providers[name] = ProviderConfig(
                api_key=pconf.get("api_key", ""),
                base_url=pconf.get("base_url", ""),
                models=models,
                type=pconf.get("type", "openai"),
                max_tokens=pconf.get("max_tokens", 4096),
                timeout=pconf.get("timeout", 120.0),
                max_connections=pconf.get("max_connections", 100),
                max_keepalive_connections=pconf.get("max_keepalive_connections", 20),
                keepalive_expiry=pconf.get("keepalive_expiry", 30.0),
            )

        # Merge top-level model_aliases from YAML (takes precedence over inline alias fields)
        top_level_aliases = data.get("model_aliases", {})
        if isinstance(top_level_aliases, dict):
            model_aliases.update(top_level_aliases)

        # Build CacheConfig if cache section is present
        cache_config = None
        cache_data = data.get("cache")
        if cache_data and isinstance(cache_data, dict):
            cache_config = CacheConfig.from_dict(cache_data)

        return LLMConfig(
            providers=providers,
            model_aliases=model_aliases,
            fallbacks=data.get("fallbacks", {}),
            cache=cache_config,
            # G4/U1: auxiliary model alias for cost-sensitive summarization.
            # Resolved via model_aliases; None = no auxiliary routing.
            auxiliary_model=data.get("auxiliary_model"),
        )

    @staticmethod
    def _build_mcp_configs(data: dict) -> dict[str, MCPServerConfig]:
        """Build MCP server configs from the mcp section of agentkit.yaml.

        Entries under ``servers`` whose value is not a mapping are skipped.
        """
        servers = data.get("servers", {})
        if not servers:
            return {}

        return {
            name: MCPServerConfig.from_dict(server_conf)
            for name, server_conf in servers.items()
            if isinstance(server_conf, dict)
        }

    @staticmethod
    def _load_skill_file(path: Path) -> SkillConfig | None:
        """Load one skill YAML file; log and return None on any failure."""
        try:
            config = SkillConfig.from_yaml(str(path))
        except Exception as e:
            # Best-effort: one broken skill file must not block the others.
            logger.warning(f"Failed to load skill config from {path}: {e}")
            return None
        logger.info(f"Loaded skill config: {config.name} from {path}")
        return config

    def load_skill_configs(self) -> list[SkillConfig]:
        """Load all SkillConfig from configured skill paths.

        Each entry in ``skill_paths`` may be a ``.yaml``/``.yml`` file or a
        directory. Directories are scanned non-recursively: first every
        ``*.yaml`` (sorted), then every ``*.yml`` (sorted). Files that fail to
        load are logged and skipped.
        """
        configs: list[SkillConfig] = []

        for skill_path in self.skill_paths:
            path = Path(skill_path)
            if path.is_file() and path.suffix in (".yaml", ".yml"):
                candidates = [path]
            elif path.is_dir():
                candidates = [*sorted(path.glob("*.yaml")), *sorted(path.glob("*.yml"))]
            else:
                continue

            for candidate in candidates:
                config = self._load_skill_file(candidate)
                if config is not None:
                    configs.append(config)

        return configs

    def watch_config(self, config_path: str | None = None) -> None:
        """Start watching the config file for changes and hot-reload.

        Uses watchfiles if available, otherwise falls back to asyncio
        polling (checks mtime every 30 seconds). If a watcher is already
        running it is stopped first so only one watcher task exists.

        Args:
            config_path: Path to the config file. If None, uses the path
                from the last from_yaml() call.
        """
        path = config_path or self._config_path
        if not path:
            logger.warning("No config path specified for watching")
            return

        # Avoid orphaning a still-running watcher task on repeated calls.
        if self._watcher_task is not None and not self._watcher_task.done():
            self.stop_watching()

        self._config_path = path
        if not self._last_mtime:
            try:
                self._last_mtime = os.path.getmtime(path)
            except OSError:
                self._last_mtime = 0.0

        try:
            import watchfiles  # noqa: F401
        except ImportError:
            self._watcher_task = asyncio.ensure_future(self._poll_config_loop(path))
            logger.info(f"Config watcher started (polling) for {path}")
        else:
            self._watcher_task = asyncio.ensure_future(self._watch_with_watchfiles(path))
            logger.info(f"Config watcher started (watchfiles) for {path}")

    def stop_watching(self) -> None:
        """Stop watching the config file."""
        if self._watcher_task is not None and not self._watcher_task.done():
            self._watcher_task.cancel()
            logger.info("Config watcher stopped")
        self._watcher_task = None

    async def _watch_with_watchfiles(self, path: str) -> None:
        """Watch config file using watchfiles library.

        On an unexpected error the watcher switches to the polling loop.
        """
        try:
            from watchfiles import awatch

            async for changes in awatch(path):
                for change_type, changed_path in changes:
                    logger.info(f"Config file change detected: {change_type} on {changed_path}")
                # Reload once per batch of changes, not once per changed path.
                self._try_reload_config(path)
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"watchfiles error, falling back to polling: {e}")
            self._watcher_task = asyncio.ensure_future(self._poll_config_loop(path))

    async def _poll_config_loop(self, path: str) -> None:
        """Fallback: poll config file mtime every 30 seconds."""
        try:
            while True:
                await asyncio.sleep(30)
                try:
                    current_mtime = os.path.getmtime(path)
                except OSError:
                    # File temporarily missing/unreadable; try again next tick.
                    continue
                if current_mtime != self._last_mtime:
                    logger.info(f"Config file change detected (mtime) for {path}")
                    self._last_mtime = current_mtime
                    self._try_reload_config(path)
        except asyncio.CancelledError:
            pass

    def _try_reload_config(self, path: str) -> None:
        """Attempt to reload config from file. On failure, keep current config.

        On success every field in ``_RELOADABLE_FIELDS`` is replaced with the
        newly parsed value and ``on_change`` (if set) is invoked with this
        instance. Exceptions raised by the callback are logged, not propagated.
        """
        try:
            new_config = ServerConfig.from_yaml(path)
        except Exception as e:
            logger.error(f"Failed to reload config from {path}: {e}. Keeping current config.")
            return

        # Apply new values. Driven by one field list so sections added to
        # __init__ are not silently left stale on reload.
        for field_name in self._RELOADABLE_FIELDS:
            setattr(self, field_name, getattr(new_config, field_name))
        self._last_mtime = new_config._last_mtime

        logger.info(f"Config reloaded from {path}")

        if self.on_change is not None:
            try:
                self.on_change(self)
            except Exception as e:
                logger.error(f"Config on_change callback error: {e}")


# ── .env loading ───────────────────────────────────────────────────────

_ALLOWED_ENV_PREFIXES = (
    "AGENTKIT_",
    "DASHSCOPE_",
    "OPENAI_",
    "ANTHROPIC_",
    "GEMINI_",
    "TAVILY_",
    "SERPER_",
    "DEEPSEEK_",
    "DOUBAO_",
)
_ALLOWED_ENV_EXACT = {"DATABASE_URL", "REDIS_URL"}


def load_dotenv(
    dotenv_path: str | Path,
    *,
    allowed_prefixes: tuple[str, ...] | None = None,
    allowed_exact: set[str] | None = None,
) -> None:
    """Load environment variables from a .env file.

    Lines that are blank, start with ``#``, or contain no ``=`` are ignored.
    Surrounding quote characters are stripped from values. A variable that
    already has a non-empty value in the environment is never overwritten.
    If the path does not point to a regular file, nothing happens.

    Args:
        dotenv_path: Path to the .env file.
        allowed_prefixes: Env var prefixes to allow.
        allowed_exact: Exact env var names to allow.

    The allowlist is applied when at least one of ``allowed_prefixes`` /
    ``allowed_exact`` is given; a key is then loaded only if it matches a
    prefix or an exact name. When both are None, every variable is loaded.
    """
    path = Path(dotenv_path)
    if not path.is_file():
        return

    use_allowlist = allowed_prefixes is not None or allowed_exact is not None

    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            if "=" not in line:
                continue

            key, _, value = line.partition("=")
            key = key.strip()
            value = value.strip().strip("\"'")

            # Skip only if key is set to a non-empty value in the environment.
            # An empty/whitespace-only value (e.g. from a shell template like
            # `${VAR:-}` that expanded to nothing) is treated as "not set" so
            # subsequent .env files can still provide a real value.
            if not key or os.environ.get(key, "").strip():
                continue

            # Apply allowlist if provided
            if use_allowlist:
                matches_prefix = bool(allowed_prefixes) and any(
                    key.startswith(p) for p in allowed_prefixes
                )
                matches_exact = bool(allowed_exact) and key in allowed_exact
                if not (matches_prefix or matches_exact):
                    logger.warning(f"Skipping .env variable '{key}' (not in allowed list)")
                    continue

            os.environ[key] = value


def load_config_with_dotenv(config_path: str | Path) -> ServerConfig:
    """Load ServerConfig with .env resolution (production-standard loading).

    1. Load ``.env``, ``.env.dev`` and ``.env.local`` (in that order) from the
       config file's parent directory; missing files are skipped
    2. Parse agentkit.yaml with env vars resolved

    This is the canonical way to load config in all CLI commands and app factory.
    """
    config_path = str(config_path)
    config_dir = Path(config_path).parent

    # Load .env, .env.dev, .env.local in order (first non-empty value wins).
    for candidate in (".env", ".env.dev", ".env.local"):
        load_dotenv(config_dir / candidate)

    return ServerConfig.from_yaml(config_path)


def find_config_path(config_arg: str | None = None) -> str | None:
    """Find the agentkit.yaml config file.

    Priority:
    1. Explicit --config argument
    2. ./agentkit.yaml in current directory
    3. ~/.agentkit/agentkit.yaml in home directory

    Returns:
        The path as a string, or None if nothing was found. An explicit
        ``config_arg`` that does not exist yields None (with a warning)
        rather than falling through to the default locations.
    """
    if config_arg:
        if Path(config_arg).exists():
            return config_arg
        logger.warning(f"Config file not found: {config_arg}")
        return None

    # Check current directory
    cwd_config = Path.cwd() / DEFAULT_CONFIG_FILE
    if cwd_config.exists():
        return str(cwd_config)

    # Check home directory
    home_config = Path.home() / ".agentkit" / DEFAULT_CONFIG_FILE
    if home_config.exists():
        return str(home_config)

    return None