585 lines
21 KiB
Python
585 lines
21 KiB
Python
"""Server configuration loader - loads agentkit.yaml and .env"""
|
||
|
||
import asyncio
|
||
import logging
|
||
import os
|
||
import re
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
from typing import Any, Callable
|
||
|
||
import yaml
|
||
|
||
from agentkit.llm.config import LLMConfig, ProviderConfig
|
||
from agentkit.skills.base import SkillConfig
|
||
|
||
# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)

# Name of the config file searched for in the working directory and in
# ~/.agentkit/ when no explicit path is given (see find_config_path).
DEFAULT_CONFIG_FILE: str = "agentkit.yaml"
|
||
|
||
|
||
@dataclass
|
||
class MCPServerConfig:
|
||
"""Configuration for a single MCP Server connection"""
|
||
|
||
transport: str # "stdio" | "streamable_http" | "sse"
|
||
# stdio-specific
|
||
command: str | None = None
|
||
args: list[str] | None = None
|
||
env: dict[str, str] | None = None
|
||
# http/sse-specific
|
||
url: str | None = None
|
||
headers: dict[str, str] | None = None
|
||
# common
|
||
timeout: float = 30.0
|
||
|
||
def validate(self) -> None:
|
||
"""Validate configuration, raise ValueError if invalid"""
|
||
if self.transport not in ("stdio", "streamable_http", "sse"):
|
||
raise ValueError(f"Invalid transport: {self.transport}")
|
||
if self.transport == "stdio" and not self.command:
|
||
raise ValueError("stdio transport requires 'command'")
|
||
if self.transport in ("streamable_http", "sse") and not self.url:
|
||
raise ValueError(f"{self.transport} transport requires 'url'")
|
||
|
||
@classmethod
|
||
def from_dict(cls, data: dict) -> "MCPServerConfig":
|
||
"""Create from dict (parsed from YAML)"""
|
||
return cls(
|
||
transport=data.get("transport", "stdio"),
|
||
command=data.get("command"),
|
||
args=data.get("args"),
|
||
env=data.get("env"),
|
||
url=data.get("url"),
|
||
headers=data.get("headers"),
|
||
timeout=data.get("timeout", 30.0),
|
||
)
|
||
|
||
|
||
# Matches ${NAME} and ${NAME:-default}; group 1 is everything between the braces.
_ENV_VAR_PATTERN = re.compile(r"\$\{([^}]+)\}")


def _resolve_env_vars(value: Any) -> Any:
    """Expand ``${VAR}`` and ``${VAR:-default}`` references in a string.

    - ``${VAR}`` is replaced by the value of ``VAR``. If ``VAR`` is unset, the
      reference is left verbatim (``ServerConfig.has_llm_provider`` relies on
      this to spot unresolved ``${...}`` placeholders).
    - ``${VAR:-default}`` follows POSIX shell semantics: ``default`` is used
      when ``VAR`` is unset *or set to the empty string*.

    Non-string values are returned unchanged.
    """
    if not isinstance(value, str):
        return value

    def replacer(match: re.Match) -> str:
        expr = match.group(1)
        if ":-" in expr:
            var_name, default = expr.split(":-", 1)
            # `or` rather than a .get() default so an empty variable also
            # falls back, exactly like the shell's `${VAR:-default}`.
            return os.environ.get(var_name) or default
        return os.environ.get(expr, match.group(0))

    return _ENV_VAR_PATTERN.sub(replacer, value)
|
||
|
||
|
||
def _deep_resolve(data: Any) -> Any:
|
||
"""Recursively resolve env vars in nested dicts/lists."""
|
||
if isinstance(data, dict):
|
||
return {k: _deep_resolve(v) for k, v in data.items()}
|
||
if isinstance(data, list):
|
||
return [_deep_resolve(item) for item in data]
|
||
if isinstance(data, str):
|
||
return _resolve_env_vars(data)
|
||
return data
|
||
|
||
|
||
class ServerConfig:
    """Server configuration loaded from agentkit.yaml.

    Scalar server options (host, port, ...) are plain attributes. Most feature
    sections (memory, telemetry, router, ...) are kept as the raw dicts read
    from the YAML file, after ``${VAR}`` expansion.
    """

    # Public attributes copied from a freshly parsed config on hot reload
    # (see _try_reload_config). Keep in sync with __init__ so that no section
    # is left stale after the file changes.
    _RELOADABLE_FIELDS: tuple[str, ...] = (
        "host",
        "port",
        "workers",
        "api_key",
        "rate_limit",
        "llm_config",
        "skill_paths",
        "auto_discover_skills",
        "log_level",
        "log_format",
        "task_store",
        "cors_origins",
        "memory",
        "mcp_servers",
        "telemetry",
        "compression",
        "session",
        "bus",
        "marketplace",
        "alignment",
        "router",
        "usage_store",
        "cascade_store",
        "evolution",
        "expert_paths",
        "board",
        "prompt_cache",
        "streaming",
        "verification",
    )

    def __init__(
        self,
        host: str = "0.0.0.0",
        port: int = 8001,
        workers: int = 1,
        api_key: str | None = None,
        rate_limit: int = 60,
        llm_config: LLMConfig | None = None,
        skill_paths: list[str] | None = None,
        auto_discover_skills: bool = True,
        log_level: str = "INFO",
        log_format: str = "text",
        task_store: dict[str, Any] | None = None,
        cors_origins: list[str] | None = None,
        memory: dict[str, Any] | None = None,
        mcp_servers: dict[str, MCPServerConfig] | None = None,
        telemetry: dict[str, Any] | None = None,
        compression: dict[str, Any] | None = None,
        session: dict[str, Any] | None = None,
        bus: dict[str, Any] | None = None,
        marketplace: dict[str, Any] | None = None,
        alignment: dict[str, Any] | None = None,
        router: dict[str, Any] | None = None,
        usage_store: dict[str, Any] | None = None,
        cascade_store: dict[str, Any] | None = None,
        evolution: dict[str, Any] | None = None,
        expert_paths: list[str] | None = None,
        board: dict[str, Any] | None = None,
        prompt_cache: dict[str, Any] | None = None,
        streaming: dict[str, Any] | None = None,
        verification: dict[str, Any] | None = None,
        on_change: Callable[["ServerConfig"], None] | None = None,
    ):
        self.host = host
        self.port = port
        self.workers = workers
        self.api_key = api_key
        self.rate_limit = rate_limit
        self.llm_config = llm_config or LLMConfig()
        self.skill_paths = skill_paths or []
        self.auto_discover_skills = auto_discover_skills
        self.log_level = log_level
        self.log_format = log_format
        self.task_store = task_store or {}
        # An explicit empty list means "no cross-origin access". Only fall back
        # to the permissive wildcard when the option is omitted entirely.
        self.cors_origins = ["*"] if cors_origins is None else cors_origins
        self.memory = memory or {}
        self.mcp_servers = mcp_servers or {}
        self.telemetry = telemetry or {}
        self.compression = compression or {}
        self.session = session or {}
        self.bus = bus or {}
        self.marketplace = marketplace or {}
        self.alignment = alignment or {}
        self.router = router or {}
        self.usage_store = usage_store or {}
        self.cascade_store = cascade_store or {}
        self.evolution = evolution or {}
        self.expert_paths = expert_paths or []
        self.board = board or {}
        self.prompt_cache = prompt_cache or {}
        # U3/G8: streaming.flush_interval_ms throttles token-chunk emission
        # (default 0 = yield every chunk).
        self.streaming = streaming or {}
        # U4/G1: verification.max_reinjections controls how many times a failed
        # verification is fed back (default 1). Has no effect when
        # verification_enabled=False.
        self.verification = verification or {}
        self.on_change = on_change

        # Config watching state
        self._config_path: str | None = None
        self._watcher_task: asyncio.Task | None = None
        self._last_mtime: float = 0.0

    def has_llm_provider(self) -> bool:
        """Return True if at least one LLM provider has a usable API key.

        A key counts as usable when it is non-empty and is not an unresolved
        ``${VAR}`` placeholder (env expansion leaves those verbatim when the
        variable is unset).
        """
        return any(
            provider.api_key and not provider.api_key.startswith("${")
            for provider in self.llm_config.providers.values()
        )

    @classmethod
    def from_yaml(cls, path: str) -> "ServerConfig":
        """Load configuration from a YAML file, expanding ``${VAR}`` references.

        The file's mtime is captured *before* it is read, so an edit that lands
        while the file is being parsed still registers as a change for the
        polling watcher instead of being silently absorbed.
        """
        mtime = os.path.getmtime(path)
        with open(path, encoding="utf-8") as f:
            data = yaml.safe_load(f) or {}

        config = cls.from_dict(_deep_resolve(data))
        config._config_path = path
        config._last_mtime = mtime
        return config

    @staticmethod
    def _section(data: dict, key: str) -> Any:
        """Return ``data[key]``, treating a missing or ``null`` section as ``{}``."""
        value = data.get(key)
        return {} if value is None else value

    @staticmethod
    def _as_int(value: Any, default: int) -> Any:
        """Return *default* for None, ``int(value)`` for strings, else *value*.

        ``${VAR}`` expansion always produces strings, so ``port: ${PORT:-8001}``
        arrives here as ``"8001"``. Non-string values pass through unchanged.

        Raises:
            ValueError: if *value* is a string that is not an integer literal.
        """
        if value is None:
            return default
        if isinstance(value, str):
            return int(value)
        return value

    @classmethod
    def from_dict(cls, data: dict) -> "ServerConfig":
        """Create a ServerConfig from a config dict (already ``${VAR}``-expanded).

        Sections that are missing, or present but empty (``null`` in YAML), are
        treated as empty. ``port``, ``workers`` and ``rate_limit`` given as
        strings are converted to int.
        """
        section = cls._section
        server = section(data, "server")
        llm_data = section(data, "llm")
        skills_data = section(data, "skills")
        logging_data = section(data, "logging")
        experts_data = section(data, "experts")

        return cls(
            host=server.get("host", "0.0.0.0"),
            port=cls._as_int(server.get("port"), 8001),
            workers=cls._as_int(server.get("workers"), 1),
            api_key=server.get("api_key"),
            rate_limit=cls._as_int(server.get("rate_limit"), 60),
            llm_config=cls._build_llm_config(llm_data),
            skill_paths=skills_data.get("paths") or [],
            auto_discover_skills=skills_data.get("auto_discover", True),
            log_level=logging_data.get("level", "INFO"),
            log_format=logging_data.get("format", "text"),
            task_store=section(data, "task_store"),
            cors_origins=server.get("cors_origins"),
            memory=section(data, "memory"),
            mcp_servers=cls._build_mcp_configs(section(data, "mcp")),
            telemetry=section(data, "telemetry"),
            compression=section(data, "compression"),
            session=section(data, "session"),
            # NOTE(review): unlike the other feature sections, `bus` is read from
            # under `server:` rather than from the top level. Confirm this is intended.
            bus=server.get("bus"),
            marketplace=section(data, "marketplace"),
            alignment=section(data, "alignment"),
            router=section(data, "router"),
            usage_store=section(data, "usage_store"),
            cascade_store=section(data, "cascade_store"),
            evolution=section(data, "evolution"),
            # Paths to YAML files defining ExpertTemplates.
            expert_paths=experts_data.get("paths") or [],
            # Board meeting options (max_rounds, default_template, ...).
            board=section(data, "board"),
            # U2/U3/U4 feature settings, passed through as read from YAML.
            prompt_cache=section(data, "prompt_cache"),
            streaming=section(data, "streaming"),
            verification=section(data, "verification"),
        )

    @staticmethod
    def _build_llm_config(data: dict) -> LLMConfig:
        """Build LLMConfig from the ``llm`` section of agentkit.yaml.

        Model aliases come from two places. One is an ``alias`` key inside a
        model's own config, mapped to ``"<provider>/<model>"``. The other is the
        top-level ``model_aliases`` mapping, which wins on conflict. Null
        ``providers``/``models``/``fallbacks`` entries are treated as empty.
        """
        from agentkit.llm.config import CacheConfig

        providers: dict[str, ProviderConfig] = {}
        model_aliases: dict[str, str] = {}

        for name, pconf in (data.get("providers") or {}).items():
            pconf = pconf or {}
            models = pconf.get("models") or {}

            for model_name, model_conf in models.items():
                alias = model_conf.get("alias") if isinstance(model_conf, dict) else None
                if alias:
                    model_aliases[alias] = f"{name}/{model_name}"

            providers[name] = ProviderConfig(
                api_key=pconf.get("api_key", ""),
                base_url=pconf.get("base_url", ""),
                models=models,
                type=pconf.get("type", "openai"),
                max_tokens=pconf.get("max_tokens", 4096),
                timeout=pconf.get("timeout", 120.0),
                max_connections=pconf.get("max_connections", 100),
                max_keepalive_connections=pconf.get("max_keepalive_connections", 20),
                keepalive_expiry=pconf.get("keepalive_expiry", 30.0),
            )

        # Top-level model_aliases take precedence over inline alias fields.
        top_level_aliases = data.get("model_aliases")
        if isinstance(top_level_aliases, dict):
            model_aliases.update(top_level_aliases)

        cache_data = data.get("cache")
        cache_config = (
            CacheConfig.from_dict(cache_data)
            if cache_data and isinstance(cache_data, dict)
            else None
        )

        return LLMConfig(
            providers=providers,
            model_aliases=model_aliases,
            fallbacks=data.get("fallbacks") or {},
            cache=cache_config,
        )

    @staticmethod
    def _build_mcp_configs(data: dict) -> dict[str, MCPServerConfig]:
        """Build MCP server configs from the ``mcp.servers`` mapping.

        Entries whose value is not a dict are skipped. The configs are not
        validated here.
        """
        servers = data.get("servers") or {}
        return {
            name: MCPServerConfig.from_dict(server_conf)
            for name, server_conf in servers.items()
            if isinstance(server_conf, dict)
        }

    def load_skill_configs(self) -> list[SkillConfig]:
        """Load every SkillConfig reachable from ``self.skill_paths``.

        Each path may be a single ``.yaml``/``.yml`` file, or a directory. For a
        directory, its ``*.yaml`` files are loaded first and then its ``*.yml``
        files, each group sorted by name (not recursive). Files that fail to
        parse are logged and skipped. Paths that are neither are logged and
        ignored.
        """
        configs: list[SkillConfig] = []
        for skill_path in self.skill_paths:
            path = Path(skill_path)
            if path.is_file() and path.suffix in (".yaml", ".yml"):
                candidates = [path]
            elif path.is_dir():
                candidates = [*sorted(path.glob("*.yaml")), *sorted(path.glob("*.yml"))]
            else:
                logger.warning(f"Skill path is not a YAML file or directory, skipping: {path}")
                continue

            for candidate in candidates:
                config = self._load_skill_file(candidate)
                if config is not None:
                    configs.append(config)
        return configs

    @staticmethod
    def _load_skill_file(path: Path) -> SkillConfig | None:
        """Parse one skill YAML file. Log the failure and return None if it cannot be loaded."""
        try:
            config = SkillConfig.from_yaml(str(path))
            logger.info(f"Loaded skill config: {config.name} from {path}")
            return config
        except Exception as e:
            logger.warning(f"Failed to load skill config from {path}: {e}")
            return None

    def watch_config(self, config_path: str | None = None) -> None:
        """Start watching the config file for changes and hot-reload.

        Uses watchfiles if it is importable. Otherwise it falls back to asyncio
        polling of the file's mtime every 30 seconds. The watcher is scheduled
        with ``asyncio.ensure_future``, so this is intended to be called from
        inside a running event loop. Calling it again replaces any previously
        started watcher.

        Args:
            config_path: Path to the config file. If None, uses the path
                from the last from_yaml() call.
        """
        path = config_path or self._config_path
        if not path:
            logger.warning("No config path specified for watching")
            return

        # Don't leave an orphaned watcher running if this is called twice.
        self.stop_watching()

        # Re-baseline the mtime if we have none yet or are switching files.
        if path != self._config_path or not self._last_mtime:
            try:
                self._last_mtime = os.path.getmtime(path)
            except OSError:
                self._last_mtime = 0.0
        self._config_path = path

        try:
            import watchfiles  # noqa: F401
        except ImportError:
            self._watcher_task = asyncio.ensure_future(self._poll_config_loop(path))
            logger.info(f"Config watcher started (polling) for {path}")
        else:
            self._watcher_task = asyncio.ensure_future(self._watch_with_watchfiles(path))
            logger.info(f"Config watcher started (watchfiles) for {path}")

    def stop_watching(self) -> None:
        """Cancel the config watcher task, if one is running."""
        if self._watcher_task is not None and not self._watcher_task.done():
            self._watcher_task.cancel()
            logger.info("Config watcher stopped")
        self._watcher_task = None

    async def _watch_with_watchfiles(self, path: str) -> None:
        """Watch *path* with watchfiles. Switch to polling if watchfiles fails."""
        try:
            from watchfiles import awatch

            async for changes in awatch(path):
                for change_type, changed_path in changes:
                    logger.info(f"Config file change detected: {change_type} on {changed_path}")
                # awatch yields a batch of events; reload once per batch rather
                # than once per event (a single save can emit several).
                self._try_reload_config(path)
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"watchfiles error, falling back to polling: {e}")
            self._watcher_task = asyncio.ensure_future(self._poll_config_loop(path))

    async def _poll_config_loop(self, path: str, interval: float = 30.0) -> None:
        """Fallback watcher: reload when *path*'s mtime changes.

        Args:
            path: Config file to poll.
            interval: Seconds to sleep between mtime checks (default 30).
        """
        try:
            while True:
                await asyncio.sleep(interval)
                try:
                    current_mtime = os.path.getmtime(path)
                except OSError:
                    # File temporarily missing (e.g. mid-rename); try next tick.
                    continue
                if current_mtime != self._last_mtime:
                    logger.info(f"Config file change detected (mtime) for {path}")
                    self._last_mtime = current_mtime
                    self._try_reload_config(path)
        except asyncio.CancelledError:
            pass

    def _try_reload_config(self, path: str) -> None:
        """Re-read *path* and apply it to this instance in place.

        If parsing fails, the current config is kept unchanged. On success,
        every field in ``_RELOADABLE_FIELDS`` is replaced and ``on_change``
        (if set) is invoked with ``self``. Errors raised by the callback are
        logged, not propagated.
        """
        try:
            new_config = ServerConfig.from_yaml(path)
        except Exception as e:
            logger.error(f"Failed to reload config from {path}: {e}. Keeping current config.")
            return

        for field_name in self._RELOADABLE_FIELDS:
            setattr(self, field_name, getattr(new_config, field_name))
        self._last_mtime = new_config._last_mtime

        logger.info(f"Config reloaded from {path}")

        if self.on_change is not None:
            try:
                self.on_change(self)
            except Exception as e:
                logger.exception(f"Config on_change callback error: {e}")
|
||
|
||
|
||
# ── .env loading ───────────────────────────────────────────────────────

# Allowlist intended for load_dotenv(allowed_prefixes=..., allowed_exact=...).
# NOTE(review): load_config_with_dotenv() does not pass these, so by default
# every .env key is loaded. Confirm whether other callers use them.
_ALLOWED_ENV_PREFIXES = tuple(
    f"{vendor}_"
    for vendor in (
        "AGENTKIT",
        "DASHSCOPE",
        "OPENAI",
        "ANTHROPIC",
        "GEMINI",
        "TAVILY",
        "SERPER",
        "DEEPSEEK",
        "DOUBAO",
    )
)

# Individual variable names allowed regardless of prefix.
_ALLOWED_ENV_EXACT = set(("DATABASE_URL", "REDIS_URL"))
|
||
|
||
|
||
def load_dotenv(
|
||
dotenv_path: str | Path,
|
||
*,
|
||
allowed_prefixes: tuple[str, ...] | None = None,
|
||
allowed_exact: set[str] | None = None,
|
||
) -> None:
|
||
"""Load environment variables from a .env file.
|
||
|
||
Only variables matching allowed prefixes or exact names are loaded.
|
||
Existing environment variables are never overwritten.
|
||
|
||
Args:
|
||
dotenv_path: Path to the .env file.
|
||
allowed_prefixes: Env var prefixes to allow. None = allow all.
|
||
allowed_exact: Exact env var names to allow. None = allow all.
|
||
"""
|
||
path = Path(dotenv_path)
|
||
if not path.exists():
|
||
return
|
||
|
||
prefixes = allowed_prefixes
|
||
exact = allowed_exact
|
||
|
||
with open(path, encoding="utf-8") as f:
|
||
for line in f:
|
||
line = line.strip()
|
||
if not line or line.startswith("#"):
|
||
continue
|
||
if "=" not in line:
|
||
continue
|
||
key, _, value = line.partition("=")
|
||
key = key.strip()
|
||
value = value.strip().strip("\"'")
|
||
if not key or key in os.environ:
|
||
continue
|
||
# Apply allowlist if provided
|
||
if prefixes is not None or exact is not None:
|
||
allowed = False
|
||
if prefixes and any(key.startswith(p) for p in prefixes):
|
||
allowed = True
|
||
if exact and key in exact:
|
||
allowed = True
|
||
if not allowed:
|
||
logger.warning(f"Skipping .env variable '{key}' (not in allowed list)")
|
||
continue
|
||
os.environ[key] = value
|
||
|
||
|
||
def load_config_with_dotenv(config_path: str | Path) -> ServerConfig:
    """Load a ServerConfig, first importing the sibling ``.env`` file.

    Steps:
      1. ``load_dotenv`` on ``.env`` in the same directory as the config file
         (no allowlist is applied; existing variables are kept).
      2. ``ServerConfig.from_yaml`` on the config file, so ``${VAR}``
         references can see the values just loaded.

    This is the canonical way to load config in all CLI commands and app factory.
    """
    path_str = str(config_path)
    env_file = Path(path_str).parent.joinpath(".env")
    load_dotenv(env_file)
    return ServerConfig.from_yaml(path_str)
|
||
|
||
|
||
def find_config_path(config_arg: str | None = None) -> str | None:
|
||
"""Find the agentkit.yaml config file.
|
||
|
||
Priority:
|
||
1. Explicit --config argument
|
||
2. ./agentkit.yaml in current directory
|
||
3. ~/.agentkit/agentkit.yaml in home directory
|
||
"""
|
||
if config_arg:
|
||
if Path(config_arg).exists():
|
||
return config_arg
|
||
logger.warning(f"Config file not found: {config_arg}")
|
||
return None
|
||
|
||
# Check current directory
|
||
cwd_config = Path.cwd() / DEFAULT_CONFIG_FILE
|
||
if cwd_config.exists():
|
||
return str(cwd_config)
|
||
|
||
# Check home directory
|
||
home_config = Path.home() / ".agentkit" / DEFAULT_CONFIG_FILE
|
||
if home_config.exists():
|
||
return str(home_config)
|
||
|
||
return None
|