fischer-agentkit/src/agentkit/server/config.py

457 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Server configuration loader - loads agentkit.yaml and .env"""
import asyncio
import logging
import os
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable
import yaml
from agentkit.llm.config import LLMConfig, ProviderConfig
from agentkit.skills.base import SkillConfig
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Default config file name, looked up by find_config_path() when no explicit
# path is given.
DEFAULT_CONFIG_FILE = "agentkit.yaml"
@dataclass
class MCPServerConfig:
"""Configuration for a single MCP Server connection"""
transport: str # "stdio" | "streamable_http" | "sse"
# stdio-specific
command: str | None = None
args: list[str] | None = None
env: dict[str, str] | None = None
# http/sse-specific
url: str | None = None
headers: dict[str, str] | None = None
# common
timeout: float = 30.0
def validate(self) -> None:
"""Validate configuration, raise ValueError if invalid"""
if self.transport not in ("stdio", "streamable_http", "sse"):
raise ValueError(f"Invalid transport: {self.transport}")
if self.transport == "stdio" and not self.command:
raise ValueError("stdio transport requires 'command'")
if self.transport in ("streamable_http", "sse") and not self.url:
raise ValueError(f"{self.transport} transport requires 'url'")
@classmethod
def from_dict(cls, data: dict) -> "MCPServerConfig":
"""Create from dict (parsed from YAML)"""
return cls(
transport=data.get("transport", "stdio"),
command=data.get("command"),
args=data.get("args"),
env=data.get("env"),
url=data.get("url"),
headers=data.get("headers"),
timeout=data.get("timeout", 30.0),
)
def _resolve_env_vars(value: Any) -> Any:
"""Resolve ${VAR:-default} patterns in string values from environment variables."""
if not isinstance(value, str):
return value
pattern = re.compile(r"\$\{([^}]+)\}")
def replacer(match):
expr = match.group(1)
if ":-" in expr:
var_name, default = expr.split(":-", 1)
return os.environ.get(var_name, default)
return os.environ.get(expr, match.group(0))
return pattern.sub(replacer, value)
def _deep_resolve(data: Any) -> Any:
"""Recursively resolve env vars in nested dicts/lists."""
if isinstance(data, dict):
return {k: _deep_resolve(v) for k, v in data.items()}
if isinstance(data, list):
return [_deep_resolve(item) for item in data]
if isinstance(data, str):
return _resolve_env_vars(data)
return data
class ServerConfig:
    """Server configuration loaded from agentkit.yaml.

    Holds the values parsed from the ``server``, ``llm``, ``skills``,
    ``logging``, ``task_store``, ``memory``, ``mcp``, ``telemetry``,
    ``compression``, ``session``, ``marketplace``, ``alignment`` and
    ``router`` sections of the config file. An instance can also watch its
    source file and hot-reload itself in place (see :meth:`watch_config`).
    """

    # Public attributes copied from a freshly loaded config on hot-reload.
    # ``on_change`` is deliberately absent: the callback belongs to the live
    # instance, not to the file contents.
    _RELOADABLE_FIELDS = (
        "host",
        "port",
        "workers",
        "api_key",
        "rate_limit",
        "llm_config",
        "skill_paths",
        "auto_discover_skills",
        "log_level",
        "log_format",
        "task_store",
        "cors_origins",
        "memory",
        "mcp_servers",
        "telemetry",
        "compression",
        "session",
        "bus",
        "marketplace",
        "alignment",
        "router",
    )

    def __init__(
        self,
        host: str = "0.0.0.0",
        port: int = 8001,
        workers: int = 1,
        api_key: str | None = None,
        rate_limit: int = 60,
        llm_config: LLMConfig | None = None,
        skill_paths: list[str] | None = None,
        auto_discover_skills: bool = True,
        log_level: str = "INFO",
        log_format: str = "text",
        task_store: dict[str, Any] | None = None,
        cors_origins: list[str] | None = None,
        memory: dict[str, Any] | None = None,
        mcp_servers: dict[str, MCPServerConfig] | None = None,
        telemetry: dict[str, Any] | None = None,
        compression: dict[str, Any] | None = None,
        session: dict[str, Any] | None = None,
        bus: dict[str, Any] | None = None,
        marketplace: dict[str, Any] | None = None,
        alignment: dict[str, Any] | None = None,
        router: dict[str, Any] | None = None,
        on_change: Callable[["ServerConfig"], None] | None = None,
    ):
        """Store the given settings.

        Falsy optional arguments are replaced by a fresh empty container
        (``{}`` / ``[]``), a default ``LLMConfig()``, or ``["*"]`` in the case
        of ``cors_origins``. ``on_change`` is called with this instance after
        every successful hot-reload.
        """
        self.host = host
        self.port = port
        self.workers = workers
        self.api_key = api_key
        self.rate_limit = rate_limit
        self.llm_config = llm_config or LLMConfig()
        self.skill_paths = skill_paths or []
        self.auto_discover_skills = auto_discover_skills
        self.log_level = log_level
        self.log_format = log_format
        self.task_store = task_store or {}
        self.cors_origins = cors_origins or ["*"]
        self.memory = memory or {}
        self.mcp_servers = mcp_servers or {}
        self.telemetry = telemetry or {}
        self.compression = compression or {}
        self.session = session or {}
        self.bus = bus or {}
        self.marketplace = marketplace or {}
        self.alignment = alignment or {}
        self.router = router or {}
        self.on_change = on_change
        # Config watching state
        self._config_path: str | None = None
        self._watcher_task: asyncio.Task | None = None
        self._last_mtime: float = 0.0

    def has_llm_provider(self) -> bool:
        """Return True if at least one configured LLM provider has a non-empty API key."""
        return any(
            provider.api_key for provider in self.llm_config.providers.values()
        )

    @classmethod
    def from_yaml(cls, path: str) -> "ServerConfig":
        """Load configuration from a YAML file.

        Environment-variable placeholders in string values are resolved
        before the config is built. The file path and its mtime are recorded
        on the instance so it can be watched later.
        """
        with open(path, encoding="utf-8") as f:
            data = yaml.safe_load(f) or {}
        # Resolve environment variables
        data = _deep_resolve(data)
        config = cls.from_dict(data)
        config._config_path = path
        config._last_mtime = os.path.getmtime(path)
        return config

    @classmethod
    def from_dict(cls, data: dict) -> "ServerConfig":
        """Create ServerConfig from a dictionary.

        Each top-level section is optional. A section that is missing or
        empty is treated as an empty mapping. Note that ``cors_origins`` and
        ``bus`` are read from inside the ``server`` section.
        """

        def section(name: str) -> dict:
            # A key written with no value in YAML ("server:") parses to None,
            # which would otherwise break the .get() calls below.
            return data.get(name) or {}

        server = section("server")
        skills_data = section("skills")
        logging_data = section("logging")

        return cls(
            host=server.get("host", "0.0.0.0"),
            port=server.get("port", 8001),
            workers=server.get("workers", 1),
            api_key=server.get("api_key"),
            rate_limit=server.get("rate_limit", 60),
            llm_config=cls._build_llm_config(section("llm")),
            skill_paths=skills_data.get("paths", []),
            auto_discover_skills=skills_data.get("auto_discover", True),
            log_level=logging_data.get("level", "INFO"),
            log_format=logging_data.get("format", "text"),
            task_store=section("task_store"),
            cors_origins=server.get("cors_origins"),
            memory=section("memory"),
            mcp_servers=cls._build_mcp_configs(section("mcp")),
            telemetry=section("telemetry"),
            compression=section("compression"),
            session=section("session"),
            bus=server.get("bus"),
            marketplace=section("marketplace"),
            alignment=section("alignment"),
            router=section("router"),
        )

    @staticmethod
    def _build_llm_config(data: dict) -> LLMConfig:
        """Build LLMConfig from the llm section of agentkit.yaml.

        Every model entry that is a mapping with an ``alias`` key contributes
        an alias of the form ``alias -> "<provider>/<model>"``. Empty or
        missing ``providers`` / ``models`` / ``fallbacks`` are treated as
        empty mappings.
        """
        data = data or {}
        providers = {}
        model_aliases = {}
        for name, pconf in (data.get("providers") or {}).items():
            # Tolerate a provider key with no body ("openai:").
            pconf = pconf or {}
            models = pconf.get("models") or {}
            # Build model aliases from alias fields
            for model_name, model_conf in models.items():
                alias = model_conf.get("alias") if isinstance(model_conf, dict) else None
                if alias:
                    model_aliases[alias] = f"{name}/{model_name}"
            providers[name] = ProviderConfig(
                api_key=pconf.get("api_key", ""),
                base_url=pconf.get("base_url", ""),
                models=models,
                type=pconf.get("type", "openai"),
                max_tokens=pconf.get("max_tokens", 4096),
                timeout=pconf.get("timeout", 120.0),
            )
        return LLMConfig(
            providers=providers,
            model_aliases=model_aliases,
            fallbacks=data.get("fallbacks") or {},
        )

    @staticmethod
    def _build_mcp_configs(data: dict) -> dict[str, MCPServerConfig]:
        """Build MCP server configs from the mcp section of agentkit.yaml.

        Entries under ``servers`` that are not mappings are skipped. The
        resulting configs are not validated here.
        """
        servers = (data or {}).get("servers") or {}
        return {
            name: MCPServerConfig.from_dict(server_conf)
            for name, server_conf in servers.items()
            if isinstance(server_conf, dict)
        }

    @staticmethod
    def _load_skill_file(path: Path) -> SkillConfig | None:
        """Load one skill YAML file; log and return None if loading fails."""
        try:
            config = SkillConfig.from_yaml(str(path))
        except Exception as e:
            # Best-effort: one broken skill file must not block the others.
            logger.warning(f"Failed to load skill config from {path}: {e}")
            return None
        logger.info(f"Loaded skill config: {config.name} from {path}")
        return config

    def load_skill_configs(self) -> list[SkillConfig]:
        """Load all SkillConfig from configured skill paths.

        A path that is a ``.yaml``/``.yml`` file is loaded directly. A path
        that is a directory contributes its ``*.yaml`` files (sorted) followed
        by its ``*.yml`` files (sorted); subdirectories are not searched.
        Files that fail to load are logged and skipped.
        """
        configs: list[SkillConfig] = []
        for skill_path in self.skill_paths:
            path = Path(skill_path)
            if path.is_file():
                candidates = [path] if path.suffix in (".yaml", ".yml") else []
            elif path.is_dir():
                candidates = [*sorted(path.glob("*.yaml")), *sorted(path.glob("*.yml"))]
            else:
                candidates = []
            for candidate in candidates:
                config = self._load_skill_file(candidate)
                if config is not None:
                    configs.append(config)
        return configs

    def load_dotenv(self, dotenv_path: str = ".env") -> None:
        """Load environment variables from a .env file (simple key=value format).

        Blank lines, ``#`` comment lines and lines without ``=`` are ignored.
        Surrounding quote characters are stripped from values. Variables that
        already exist in ``os.environ`` are never overwritten. A missing file
        is silently ignored.
        """
        path = Path(dotenv_path)
        if not path.exists():
            return
        with open(path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith("#") or "=" not in line:
                    continue
                key, _, value = line.partition("=")
                key = key.strip()
                value = value.strip().strip("\"'")
                if key and key not in os.environ:
                    os.environ[key] = value

    def watch_config(self, config_path: str | None = None) -> None:
        """Start watching the config file for changes and hot-reload.

        Uses watchfiles if available, otherwise falls back to asyncio polling
        (checks mtime every 30 seconds). Any watcher already running on this
        instance is stopped first.

        NOTE(review): the watcher is scheduled with ``asyncio.ensure_future``,
        so this presumably has to be called while an event loop is available.

        Args:
            config_path: Path to the config file. If None, uses the path
                from the last from_yaml() call.
        """
        path = config_path or self._config_path
        if not path:
            logger.warning("No config path specified for watching")
            return
        # Avoid orphaning a previous watcher task when called more than once.
        self.stop_watching()
        self._config_path = path
        if not self._last_mtime:
            try:
                self._last_mtime = os.path.getmtime(path)
            except OSError:
                self._last_mtime = 0.0
        try:
            import watchfiles  # noqa: F401
        except ImportError:
            runner, mode = self._poll_config_loop, "polling"
        else:
            runner, mode = self._watch_with_watchfiles, "watchfiles"
        self._watcher_task = asyncio.ensure_future(runner(path))
        logger.info(f"Config watcher started ({mode}) for {path}")

    def stop_watching(self) -> None:
        """Stop watching the config file."""
        if self._watcher_task is not None and not self._watcher_task.done():
            self._watcher_task.cancel()
            logger.info("Config watcher stopped")
        self._watcher_task = None

    async def _watch_with_watchfiles(self, path: str) -> None:
        """Watch config file using watchfiles library.

        On any unexpected error the watcher switches to the polling loop.
        """
        try:
            from watchfiles import awatch

            async for changes in awatch(path):
                for change_type, changed_path in changes:
                    logger.info(f"Config file change detected: {change_type} on {changed_path}")
                # Reload once per batch of reported changes.
                self._try_reload_config(path)
        except asyncio.CancelledError:
            # Cancelled by stop_watching(); end quietly.
            pass
        except Exception as e:
            logger.error(f"watchfiles error, falling back to polling: {e}")
            self._watcher_task = asyncio.ensure_future(self._poll_config_loop(path))

    async def _poll_config_loop(self, path: str) -> None:
        """Fallback: poll config file mtime every 30 seconds."""
        try:
            while True:
                await asyncio.sleep(30)
                try:
                    current_mtime = os.path.getmtime(path)
                except OSError:
                    # File temporarily missing or unreadable; try again next tick.
                    continue
                if current_mtime != self._last_mtime:
                    logger.info(f"Config file change detected (mtime) for {path}")
                    # Record the mtime before reloading so a file that fails
                    # to parse is not retried on every tick.
                    self._last_mtime = current_mtime
                    self._try_reload_config(path)
        except asyncio.CancelledError:
            # Cancelled by stop_watching(); end quietly.
            pass

    def _try_reload_config(self, path: str) -> None:
        """Attempt to reload config from file. On failure, keep current config.

        On success every attribute listed in ``_RELOADABLE_FIELDS`` is
        replaced in place, then ``on_change`` (if set) is called with this
        instance. Exceptions raised by the callback are logged, not raised.
        """
        try:
            new_config = ServerConfig.from_yaml(path)
        except Exception as e:
            logger.error(f"Failed to reload config from {path}: {e}. Keeping current config.")
            return
        # Apply new values
        for name in self._RELOADABLE_FIELDS:
            setattr(self, name, getattr(new_config, name))
        self._last_mtime = new_config._last_mtime
        logger.info(f"Config reloaded from {path}")
        if self.on_change is not None:
            try:
                self.on_change(self)
            except Exception as e:
                logger.error(f"Config on_change callback error: {e}")
def find_config_path(config_arg: str | None = None) -> str | None:
"""Find the agentkit.yaml config file.
Priority:
1. Explicit --config argument
2. ./agentkit.yaml in current directory
3. ~/.agentkit/agentkit.yaml in home directory
"""
if config_arg:
if Path(config_arg).exists():
return config_arg
logger.warning(f"Config file not found: {config_arg}")
return None
# Check current directory
cwd_config = Path.cwd() / DEFAULT_CONFIG_FILE
if cwd_config.exists():
return str(cwd_config)
# Check home directory
home_config = Path.home() / ".agentkit" / DEFAULT_CONFIG_FILE
if home_config.exists():
return str(home_config)
return None