# fischer-agentkit/src/agentkit/server/config.py
#
# 585 lines
# 21 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Server configuration loader - loads agentkit.yaml and .env"""
import asyncio
import logging
import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable
import yaml
from agentkit.llm.config import LLMConfig, ProviderConfig
from agentkit.skills.base import SkillConfig
# Module logger; used for load/reload/watch diagnostics throughout this module.
logger = logging.getLogger(__name__)
# Default config file name, searched for by find_config_path() in the current
# directory and then in ~/.agentkit/.
DEFAULT_CONFIG_FILE = "agentkit.yaml"
@dataclass
class MCPServerConfig:
"""Configuration for a single MCP Server connection"""
transport: str # "stdio" | "streamable_http" | "sse"
# stdio-specific
command: str | None = None
args: list[str] | None = None
env: dict[str, str] | None = None
# http/sse-specific
url: str | None = None
headers: dict[str, str] | None = None
# common
timeout: float = 30.0
def validate(self) -> None:
"""Validate configuration, raise ValueError if invalid"""
if self.transport not in ("stdio", "streamable_http", "sse"):
raise ValueError(f"Invalid transport: {self.transport}")
if self.transport == "stdio" and not self.command:
raise ValueError("stdio transport requires 'command'")
if self.transport in ("streamable_http", "sse") and not self.url:
raise ValueError(f"{self.transport} transport requires 'url'")
@classmethod
def from_dict(cls, data: dict) -> "MCPServerConfig":
"""Create from dict (parsed from YAML)"""
return cls(
transport=data.get("transport", "stdio"),
command=data.get("command"),
args=data.get("args"),
env=data.get("env"),
url=data.get("url"),
headers=data.get("headers"),
timeout=data.get("timeout", 30.0),
)
def _resolve_env_vars(value: Any) -> Any:
"""Resolve ${VAR:-default} patterns in string values from environment variables."""
if not isinstance(value, str):
return value
pattern = re.compile(r"\$\{([^}]+)\}")
def replacer(match):
expr = match.group(1)
if ":-" in expr:
var_name, default = expr.split(":-", 1)
return os.environ.get(var_name, default)
return os.environ.get(expr, match.group(0))
return pattern.sub(replacer, value)
def _deep_resolve(data: Any) -> Any:
"""Recursively resolve env vars in nested dicts/lists."""
if isinstance(data, dict):
return {k: _deep_resolve(v) for k, v in data.items()}
if isinstance(data, list):
return [_deep_resolve(item) for item in data]
if isinstance(data, str):
return _resolve_env_vars(data)
return data
class ServerConfig:
    """Server configuration loaded from agentkit.yaml.

    Core server settings, the LLM provider config, skill paths and MCP servers
    are parsed into attributes. The other optional sections (task_store,
    memory, telemetry, compression, ...) are kept as plain dicts for the
    subsystems that consume them. An instance created by ``from_yaml``
    remembers its file and can hot-reload itself in place
    (``watch_config`` / ``stop_watching``).
    """

    # Attributes replaced wholesale on hot reload. This must name every
    # user-configurable attribute set in __init__ (all parameters except
    # ``on_change``). A section missing here would be silently ignored when the
    # file changes.
    _RELOADABLE_FIELDS: tuple[str, ...] = (
        "host",
        "port",
        "workers",
        "api_key",
        "rate_limit",
        "llm_config",
        "skill_paths",
        "auto_discover_skills",
        "log_level",
        "log_format",
        "task_store",
        "cors_origins",
        "memory",
        "mcp_servers",
        "telemetry",
        "compression",
        "session",
        "bus",
        "marketplace",
        "alignment",
        "router",
        "usage_store",
        "cascade_store",
        "evolution",
        "expert_paths",
        "board",
        "prompt_cache",
        "streaming",
        "verification",
    )

    def __init__(
        self,
        host: str = "0.0.0.0",
        port: int = 8001,
        workers: int = 1,
        api_key: str | None = None,
        rate_limit: int = 60,
        llm_config: LLMConfig | None = None,
        skill_paths: list[str] | None = None,
        auto_discover_skills: bool = True,
        log_level: str = "INFO",
        log_format: str = "text",
        task_store: dict[str, Any] | None = None,
        cors_origins: list[str] | None = None,
        memory: dict[str, Any] | None = None,
        mcp_servers: dict[str, MCPServerConfig] | None = None,
        telemetry: dict[str, Any] | None = None,
        compression: dict[str, Any] | None = None,
        session: dict[str, Any] | None = None,
        bus: dict[str, Any] | None = None,
        marketplace: dict[str, Any] | None = None,
        alignment: dict[str, Any] | None = None,
        router: dict[str, Any] | None = None,
        usage_store: dict[str, Any] | None = None,
        cascade_store: dict[str, Any] | None = None,
        evolution: dict[str, Any] | None = None,
        expert_paths: list[str] | None = None,
        board: dict[str, Any] | None = None,
        prompt_cache: dict[str, Any] | None = None,
        streaming: dict[str, Any] | None = None,
        verification: dict[str, Any] | None = None,
        on_change: Callable[["ServerConfig"], None] | None = None,
    ):
        """Store the settings.

        - A dict/list section left as ``None`` becomes an empty container.
        - ``llm_config=None`` becomes a default ``LLMConfig()``.
        - ``cors_origins=None`` becomes ``["*"]``. An explicit empty list is
          kept as given rather than being widened to ``["*"]``.
        - ``on_change`` is called with this instance after every successful
          hot reload.
        """
        self.host = host
        self.port = port
        self.workers = workers
        self.api_key = api_key
        self.rate_limit = rate_limit
        self.llm_config = llm_config or LLMConfig()
        self.skill_paths = skill_paths or []
        self.auto_discover_skills = auto_discover_skills
        self.log_level = log_level
        self.log_format = log_format
        self.task_store = task_store or {}
        # Only a missing value defaults to the wildcard. `or` would also turn
        # an explicitly empty allowlist into "allow every origin".
        self.cors_origins = ["*"] if cors_origins is None else cors_origins
        self.memory = memory or {}
        self.mcp_servers = mcp_servers or {}
        self.telemetry = telemetry or {}
        self.compression = compression or {}
        self.session = session or {}
        self.bus = bus or {}
        self.marketplace = marketplace or {}
        self.alignment = alignment or {}
        self.router = router or {}
        self.usage_store = usage_store or {}
        self.cascade_store = cascade_store or {}
        self.evolution = evolution or {}
        self.expert_paths = expert_paths or []
        self.board = board or {}
        self.prompt_cache = prompt_cache or {}
        # U3/G8: streaming.flush_interval_ms throttles token-chunk yielding
        # (default 0 = yield every chunk).
        self.streaming = streaming or {}
        # U4/G1: verification.max_reinjections is how many times a failed
        # verification result is fed back (default 1). It has no effect when
        # verification_enabled=False.
        self.verification = verification or {}
        self.on_change = on_change
        # Hot-reload state: source file, running watcher task, last seen mtime.
        self._config_path: str | None = None
        self._watcher_task: asyncio.Task | None = None
        self._last_mtime: float = 0.0

    def has_llm_provider(self) -> bool:
        """Return True if at least one LLM provider has a usable API key.

        A key is usable when it is non-empty and is not a leftover ``${VAR}``
        placeholder (an environment variable that was unset at load time).
        """
        return any(
            provider.api_key and not provider.api_key.startswith("${")
            for provider in self.llm_config.providers.values()
        )

    @classmethod
    def from_yaml(cls, path: str) -> "ServerConfig":
        """Load configuration from a YAML file, expanding ``${VAR}`` references.

        Raises:
            OSError: if the file cannot be read.
            yaml.YAMLError: if the file is not valid YAML.
            ValueError: if the document's top level is not a mapping, or a
                numeric server setting cannot be parsed.
        """
        with open(path, encoding="utf-8") as f:
            data = yaml.safe_load(f) or {}
        if not isinstance(data, dict):
            raise ValueError(
                f"Config file {path} must contain a mapping at the top level, "
                f"got {type(data).__name__}"
            )
        data = _deep_resolve(data)
        config = cls.from_dict(data)
        config._config_path = path
        config._last_mtime = os.path.getmtime(path)
        return config

    @staticmethod
    def _section(data: dict, key: str) -> Any:
        """Return ``data[key]``; a missing or null section (``key:`` with no body) is ``{}``."""
        value = data.get(key)
        return {} if value is None else value

    @staticmethod
    def _maybe_int(value: Any) -> Any:
        """Parse a string (e.g. from ``${PORT:-8001}``) as int; pass other values through."""
        if isinstance(value, str):
            return int(value.strip())
        return value

    @classmethod
    def from_dict(cls, data: dict) -> "ServerConfig":
        """Create a ServerConfig from an already-parsed (and env-resolved) dict.

        Sections that are missing, or present but left empty in YAML, fall back
        to defaults. ``port``/``workers``/``rate_limit`` given as strings (as
        produced by env-var interpolation) are converted to int.
        """
        section = cls._section
        server = section(data, "server")
        skills_data = section(data, "skills")
        logging_data = section(data, "logging")
        # Expert templates config (paths to YAML files defining ExpertTemplates)
        experts_data = section(data, "experts")
        return cls(
            host=server.get("host", "0.0.0.0"),
            port=cls._maybe_int(server.get("port", 8001)),
            workers=cls._maybe_int(server.get("workers", 1)),
            api_key=server.get("api_key"),
            rate_limit=cls._maybe_int(server.get("rate_limit", 60)),
            llm_config=cls._build_llm_config(section(data, "llm")),
            skill_paths=skills_data.get("paths", []),
            auto_discover_skills=skills_data.get("auto_discover", True),
            log_level=logging_data.get("level", "INFO"),
            log_format=logging_data.get("format", "text"),
            task_store=section(data, "task_store"),
            cors_origins=server.get("cors_origins"),
            memory=section(data, "memory"),
            mcp_servers=cls._build_mcp_configs(section(data, "mcp")),
            telemetry=section(data, "telemetry"),
            compression=section(data, "compression"),
            session=section(data, "session"),
            # NOTE(review): unlike every other section, ``bus`` is read from
            # inside ``server:`` rather than the top level. Confirm intent.
            bus=server.get("bus"),
            marketplace=section(data, "marketplace"),
            alignment=section(data, "alignment"),
            router=section(data, "router"),
            usage_store=section(data, "usage_store"),
            cascade_store=section(data, "cascade_store"),
            evolution=section(data, "evolution"),
            expert_paths=experts_data.get("paths", []),
            # Board meeting config (max_rounds, default_template, etc.)
            board=section(data, "board"),
            # U2/U3/U4: prompt_cache / streaming / verification (read from YAML)
            prompt_cache=section(data, "prompt_cache"),
            streaming=section(data, "streaming"),
            verification=section(data, "verification"),
        )

    @staticmethod
    def _build_llm_config(data: dict) -> LLMConfig:
        """Build LLMConfig from the ``llm`` section of agentkit.yaml.

        Model aliases come from two places:

        - an ``alias`` key inside a model's config, mapped to
          ``"<provider>/<model>"``;
        - the top-level ``model_aliases`` mapping, which wins on conflict.

        Null ``providers``/``models``/``fallbacks`` entries are treated as empty.
        """
        from agentkit.llm.config import CacheConfig

        providers = {}
        model_aliases = {}
        for name, pconf in (data.get("providers") or {}).items():
            pconf = pconf or {}  # a provider key with no body parses as None
            models = pconf.get("models") or {}
            # Build model aliases from alias fields within model configs
            for model_name, model_conf in models.items():
                alias = model_conf.get("alias") if isinstance(model_conf, dict) else None
                if alias:
                    model_aliases[alias] = f"{name}/{model_name}"
            providers[name] = ProviderConfig(
                api_key=pconf.get("api_key", ""),
                base_url=pconf.get("base_url", ""),
                models=models,
                type=pconf.get("type", "openai"),
                max_tokens=pconf.get("max_tokens", 4096),
                timeout=pconf.get("timeout", 120.0),
                max_connections=pconf.get("max_connections", 100),
                max_keepalive_connections=pconf.get("max_keepalive_connections", 20),
                keepalive_expiry=pconf.get("keepalive_expiry", 30.0),
            )
        # Merge top-level model_aliases (takes precedence over inline alias fields)
        top_level_aliases = data.get("model_aliases", {})
        if isinstance(top_level_aliases, dict):
            model_aliases.update(top_level_aliases)
        # Build CacheConfig only if a cache mapping is present
        cache_data = data.get("cache")
        cache_config = (
            CacheConfig.from_dict(cache_data)
            if cache_data and isinstance(cache_data, dict)
            else None
        )
        return LLMConfig(
            providers=providers,
            model_aliases=model_aliases,
            fallbacks=data.get("fallbacks") or {},
            cache=cache_config,
        )

    @staticmethod
    def _build_mcp_configs(data: dict) -> dict[str, MCPServerConfig]:
        """Build MCP server configs from the ``mcp`` section of agentkit.yaml.

        Entries under ``servers`` that are not mappings are skipped with a
        warning. Configs are not validated here; see MCPServerConfig.validate.
        """
        servers = data.get("servers") or {}
        result = {}
        for name, server_conf in servers.items():
            if not isinstance(server_conf, dict):
                logger.warning(
                    f"Ignoring MCP server '{name}': expected a mapping, "
                    f"got {type(server_conf).__name__}"
                )
                continue
            result[name] = MCPServerConfig.from_dict(server_conf)
        return result

    def load_skill_configs(self) -> list[SkillConfig]:
        """Load every SkillConfig found in the configured skill paths.

        A path may be a single ``.yaml``/``.yml`` file or a directory. For a
        directory, its ``*.yaml`` files load first, then its ``*.yml`` files,
        each group sorted by name. Files that fail to load are logged and
        skipped. Paths that are neither are logged and ignored.
        """
        configs = []
        for skill_path in self.skill_paths:
            path = Path(skill_path)
            if path.is_file() and path.suffix in (".yaml", ".yml"):
                candidates = [path]
            elif path.is_dir():
                candidates = [*sorted(path.glob("*.yaml")), *sorted(path.glob("*.yml"))]
            else:
                logger.warning(f"Skill path is not a YAML file or directory: {path}")
                continue
            for yaml_file in candidates:
                config = self._load_skill_file(yaml_file)
                if config is not None:
                    configs.append(config)
        return configs

    @staticmethod
    def _load_skill_file(path: Path) -> SkillConfig | None:
        """Load one skill YAML file, logging and returning None on any error."""
        try:
            config = SkillConfig.from_yaml(str(path))
        except Exception as e:
            logger.warning(f"Failed to load skill config from {path}: {e}")
            return None
        logger.info(f"Loaded skill config: {config.name} from {path}")
        return config

    def watch_config(self, config_path: str | None = None) -> None:
        """Start watching the config file for changes and hot-reload.

        Uses watchfiles if available, otherwise falls back to asyncio polling
        (checks mtime every 30 seconds). The watcher is scheduled with
        ``asyncio.ensure_future``, so call this from a running event loop.
        Any previously started watcher is stopped first.

        Args:
            config_path: Path to the config file. If None, uses the path
                from the last from_yaml() call.
        """
        path = config_path or self._config_path
        if not path:
            logger.warning("No config path specified for watching")
            return
        # Never run two watchers: overwriting _watcher_task would orphan the
        # old task, which would keep reloading and could no longer be stopped.
        self.stop_watching()
        self._config_path = path
        if not self._last_mtime:
            try:
                self._last_mtime = os.path.getmtime(path)
            except OSError:
                self._last_mtime = 0.0
        try:
            import watchfiles  # noqa: F401
        except ImportError:
            self._watcher_task = asyncio.ensure_future(self._poll_config_loop(path))
            logger.info(f"Config watcher started (polling) for {path}")
        else:
            self._watcher_task = asyncio.ensure_future(self._watch_with_watchfiles(path))
            logger.info(f"Config watcher started (watchfiles) for {path}")

    def stop_watching(self) -> None:
        """Cancel the running watcher task, if any. Safe to call repeatedly."""
        if self._watcher_task is not None and not self._watcher_task.done():
            self._watcher_task.cancel()
            logger.info("Config watcher stopped")
        self._watcher_task = None

    async def _watch_with_watchfiles(self, path: str) -> None:
        """Watch ``path`` with watchfiles and reload once per batch of changes.

        Cancellation ends the watcher quietly. Any other watcher error switches
        to mtime polling.
        """
        try:
            from watchfiles import awatch

            async for changes in awatch(path):
                for change_type, changed_path in changes:
                    logger.info(f"Config file change detected: {change_type} on {changed_path}")
                # awatch yields a set of changes per batch; a single reload
                # picks up all of them.
                self._try_reload_config(path)
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"watchfiles error, falling back to polling: {e}")
            self._watcher_task = asyncio.ensure_future(self._poll_config_loop(path))

    async def _poll_config_loop(self, path: str) -> None:
        """Fallback watcher: reload when the file's mtime changes (checked every 30 s).

        If the file cannot be stat'ed (e.g. mid-replace), the tick is skipped.
        """
        try:
            while True:
                await asyncio.sleep(30)
                try:
                    current_mtime = os.path.getmtime(path)
                except OSError:
                    continue
                if current_mtime != self._last_mtime:
                    logger.info(f"Config file change detected (mtime) for {path}")
                    self._last_mtime = current_mtime
                    self._try_reload_config(path)
        except asyncio.CancelledError:
            pass

    def _apply_config(self, other: "ServerConfig") -> None:
        """Copy every reloadable setting (and the file mtime) from ``other`` onto self."""
        for field_name in self._RELOADABLE_FIELDS:
            setattr(self, field_name, getattr(other, field_name))
        self._last_mtime = other._last_mtime

    def _try_reload_config(self, path: str) -> None:
        """Reload config from ``path`` in place. On failure, keep the current config.

        After a successful reload, ``on_change`` (if set) is called with this
        instance. Errors raised by the callback are logged, not propagated.
        """
        try:
            new_config = type(self).from_yaml(path)
        except Exception as e:
            logger.error(f"Failed to reload config from {path}: {e}. Keeping current config.")
            return
        self._apply_config(new_config)
        logger.info(f"Config reloaded from {path}")
        if self.on_change is not None:
            try:
                self.on_change(self)
            except Exception as e:
                logger.exception(f"Config on_change callback error: {e}")
# ── .env loading ───────────────────────────────────────────────────────
# Allowlist suitable for load_dotenv(allowed_prefixes=...): keys starting with
# one of these prefixes would be accepted from a .env file.
# NOTE(review): not referenced by the code visible here. load_config_with_dotenv()
# calls load_dotenv() without an allowlist, so every .env key is loaded.
# Confirm whether these should be passed there.
_ALLOWED_ENV_PREFIXES = (
    "AGENTKIT_",
    "DASHSCOPE_",
    "OPENAI_",
    "ANTHROPIC_",
    "GEMINI_",
    "TAVILY_",
    "SERPER_",
    "DEEPSEEK_",
    "DOUBAO_",
)
# Exact key names for load_dotenv(allowed_exact=...), accepted in addition to
# the prefixes above.
_ALLOWED_ENV_EXACT = {"DATABASE_URL", "REDIS_URL"}
def load_dotenv(
dotenv_path: str | Path,
*,
allowed_prefixes: tuple[str, ...] | None = None,
allowed_exact: set[str] | None = None,
) -> None:
"""Load environment variables from a .env file.
Only variables matching allowed prefixes or exact names are loaded.
Existing environment variables are never overwritten.
Args:
dotenv_path: Path to the .env file.
allowed_prefixes: Env var prefixes to allow. None = allow all.
allowed_exact: Exact env var names to allow. None = allow all.
"""
path = Path(dotenv_path)
if not path.exists():
return
prefixes = allowed_prefixes
exact = allowed_exact
with open(path, encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line or line.startswith("#"):
continue
if "=" not in line:
continue
key, _, value = line.partition("=")
key = key.strip()
value = value.strip().strip("\"'")
if not key or key in os.environ:
continue
# Apply allowlist if provided
if prefixes is not None or exact is not None:
allowed = False
if prefixes and any(key.startswith(p) for p in prefixes):
allowed = True
if exact and key in exact:
allowed = True
if not allowed:
logger.warning(f"Skipping .env variable '{key}' (not in allowed list)")
continue
os.environ[key] = value
def load_config_with_dotenv(config_path: str | Path) -> ServerConfig:
    """Load a ServerConfig after importing the ``.env`` file next to it.

    Steps:
      1. Load ``.env`` from the directory containing the config file
         (variables already in the environment are kept; a missing file
         is skipped).
      2. Parse the YAML config with ``${VAR}`` references resolved.

    This is the canonical loader for all CLI commands and the app factory.

    NOTE(review): load_dotenv() is called without ``allowed_prefixes`` /
    ``allowed_exact``, so every key in ``.env`` is loaded. The module-level
    ``_ALLOWED_ENV_PREFIXES`` / ``_ALLOWED_ENV_EXACT`` are not applied here.
    Confirm that is intended.
    """
    path_str = str(config_path)
    env_file = Path(path_str).parent / ".env"
    load_dotenv(env_file)
    return ServerConfig.from_yaml(path_str)
def find_config_path(config_arg: str | None = None) -> str | None:
"""Find the agentkit.yaml config file.
Priority:
1. Explicit --config argument
2. ./agentkit.yaml in current directory
3. ~/.agentkit/agentkit.yaml in home directory
"""
if config_arg:
if Path(config_arg).exists():
return config_arg
logger.warning(f"Config file not found: {config_arg}")
return None
# Check current directory
cwd_config = Path.cwd() / DEFAULT_CONFIG_FILE
if cwd_config.exists():
return str(cwd_config)
# Check home directory
home_config = Path.home() / ".agentkit" / DEFAULT_CONFIG_FILE
if home_config.exists():
return str(home_config)
return None