fischer-agentkit/src/agentkit/server/admin/skill_service.py

429 lines
16 KiB
Python

"""SkillService — admin-driven skill enable/disable + import/reload (U6).
This module is the single owner of the ``skill_states`` table (admin
disable markers) and the YAML import/reload pipeline. Web UI routes
(``/api/v1/admin/skills/*``) and the existing skill-management route
both call into :class:`SkillService` rather than touching the table
or filesystem directly, keeping validation rules (skill-name regex,
path-traversal guard, YAML validation) in one place.
The service is a module-level singleton (see :func:`get_skill_service`)
so tests can inject a custom instance via :func:`set_skill_service`.
Design notes
------------
- Each DB-touching method (enable/disable/list) opens its own short-lived
:class:`aiosqlite.Connection` (mirrors :class:`DepartmentService`).
- ``import_skill`` reuses the validation regex from
:mod:`agentkit.server.routes.skills` but inlines a copy to avoid a
circular import (routes.skills → admin.skill_service → routes.skills).
- ``reload_skill`` and ``update_skill_config`` require (and ``import_skill``
optionally accepts) a :class:`SkillRegistry` instance so the caller can
pass the live registry from ``app.state``.
"""
from __future__ import annotations
import logging
import os
import re
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any
import aiosqlite
import yaml
from agentkit.server.auth.models import skill_state_row_to_dict
if TYPE_CHECKING:
from agentkit.skills.registry import SkillRegistry
# One logger per module, named after the module's import path.
logger = logging.getLogger(name=__name__)
# ---------------------------------------------------------------------------
# Constants & helpers
# ---------------------------------------------------------------------------
# Skill names are 1-64 characters of lowercase ASCII letters, digits, "-" and
# "_", and must begin with a letter or digit. Duplicated from routes/skills.py
# (not imported, to avoid a circular import) — keep both copies in sync.
_SKILL_NAME_RE = re.compile(
    r"^[a-z0-9][a-z0-9_-]{0,63}$"
)
def _now_iso() -> str:
    """Return the current moment in UTC, formatted as an ISO 8601 string."""
    utc_now = datetime.now(tz=timezone.utc)
    return utc_now.isoformat()
def _validate_skill_name(name: str) -> str:
    """Normalize *name* and check that it is an acceptable skill name.

    The input is stripped of surrounding whitespace and lower-cased, then
    matched against ``_SKILL_NAME_RE``. The rule duplicates the one in
    :mod:`agentkit.server.routes.skills`. Failures are signalled with
    ``ValueError`` rather than ``HTTPException``, which keeps this service
    layer independent of the HTTP framework.

    Returns:
        The normalized (stripped, lower-cased) skill name.

    Raises:
        ValueError: If *name* is not a ``str``, or the normalized value
            does not match the allowed pattern.
    """
    if isinstance(name, str):
        candidate = name.strip().lower()
        if _SKILL_NAME_RE.match(candidate):
            return candidate
        raise ValueError(
            f"Invalid skill name {name!r}: must contain only lowercase letters, "
            f"digits, hyphens, and underscores (1-64 chars)"
        )
    raise ValueError("Skill name must be a string")
def _validate_yaml_content(content: str) -> dict[str, Any]:
    """Parse *content* as YAML and check that it looks like a skill definition.

    The checks match :mod:`agentkit.server.routes.skills`, but failures are
    reported as ``ValueError`` rather than ``HTTPException``. Parsing uses
    ``yaml.safe_load``, so arbitrary Python objects are never constructed.

    Returns:
        The parsed mapping.

    Raises:
        ValueError: If *content* is not valid YAML, does not parse to a
            ``dict``, or has no ``name`` key.
    """
    try:
        parsed = yaml.safe_load(content)
    except yaml.YAMLError as err:
        raise ValueError(f"Invalid YAML content: {err}") from err

    problem = None
    if not isinstance(parsed, dict):
        problem = "Skill YAML must be a mapping/dict"
    elif "name" not in parsed:
        problem = "Skill YAML must contain a 'name' field"
    if problem is not None:
        raise ValueError(problem)
    return parsed
def _is_within_directory(file_path: str, directory: str) -> bool:
    """Return ``True`` if ``file_path`` resolves to ``directory`` or inside it.

    Both paths are canonicalized with :func:`os.path.realpath` first, so
    ``..`` segments and symlinks are resolved before the comparison.

    Containment is decided with :func:`os.path.commonpath` instead of a
    string-prefix test. With a prefix test of ``real_dir + os.sep``, a
    filesystem root such as ``/`` becomes ``//`` and matches no real path.
    Comparing through :func:`os.path.normcase` makes the check
    case-insensitive on Windows, like the filesystem.
    """
    real_file = os.path.realpath(file_path)
    real_dir = os.path.realpath(directory)
    try:
        common = os.path.commonpath([real_file, real_dir])
    except ValueError:
        # commonpath raises for paths on different drives (Windows); such
        # paths cannot be nested inside one another.
        return False
    return os.path.normcase(common) == os.path.normcase(real_dir)
# ---------------------------------------------------------------------------
# Service
# ---------------------------------------------------------------------------
class SkillService:
    """Admin-driven skill enable/disable + import/reload operations.

    Enable/disable methods are async and take ``db_path: Path`` as their
    first argument. Each call opens its own short-lived :mod:`aiosqlite`
    connection. YAML methods take ``skills_dir: str`` and, for live
    (re)registration, a :class:`SkillRegistry` instance.

    Sometimes a YAML method rewrites a skill file and the skill then fails
    to load. In that case the previous file contents are restored (or a
    newly created file is removed). A rejected import or update therefore
    does not destroy a working definition on disk.
    """

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _skill_file_path(skills_dir: str, normalized: str) -> str:
        """Build ``{skills_dir}/{normalized}.yaml`` and reject paths escaping it.

        ``normalized`` is expected to have passed ``_validate_skill_name``
        already, so this is a defence-in-depth check.

        Raises:
            ValueError: If the resolved path is outside ``skills_dir``.
        """
        file_path = os.path.join(skills_dir, f"{normalized}.yaml")
        if not _is_within_directory(file_path, skills_dir):
            raise ValueError(f"Resolved skill path escapes skills_dir: {file_path}")
        return file_path

    @staticmethod
    def _load_skill_file(
        file_path: str,
        skill_registry: SkillRegistry,
        tool_registry: Any,
    ) -> None:
        """Load the skill YAML at ``file_path`` into ``skill_registry``."""
        # Function-level import, as in the original call sites.
        from agentkit.skills.loader import SkillLoader

        loader = SkillLoader(
            skill_registry=skill_registry,
            tool_registry=tool_registry,
        )
        loader.load_from_file(file_path)

    @staticmethod
    def _unregister_quietly(
        skill_registry: SkillRegistry,
        name: str,
        action: str,
    ) -> None:
        """Unregister ``name`` if present; any failure is logged at DEBUG and ignored."""
        try:
            skill_registry.unregister(name)
        except Exception:  # noqa: BLE001 — unregister is best-effort
            logger.debug("Skill %r was not registered before %s", name, action)

    @staticmethod
    def _restore_file(file_path: str, previous: bytes | None) -> None:
        """Roll ``file_path`` back to ``previous``, or delete it when ``None``.

        Errors are logged rather than raised, so the caller can still report
        the failure that triggered the rollback.
        """
        try:
            if previous is None:
                os.remove(file_path)
            else:
                with open(file_path, "wb") as f:
                    f.write(previous)
        except OSError:
            logger.warning("Failed to roll back skill file %s", file_path, exc_info=True)

    # ------------------------------------------------------------------
    # Enable / disable (DB state)
    # ------------------------------------------------------------------

    async def disable_skill(
        self,
        db_path: Path,
        skill_name: str,
        disabled_by: str | None = None,
    ) -> dict[str, Any]:
        """Mark a skill as disabled in the DB.

        Idempotent: if the skill already has a row, that row is updated
        (refreshing ``disabled_at`` and ``disabled_by``).

        Args:
            db_path: Path to the auth SQLite DB.
            skill_name: Skill name to disable (validated against the regex).
            disabled_by: Optional admin user id (audit trail).

        Returns:
            The disabled skill state dict, as built by
            ``skill_state_row_to_dict``.

        Raises:
            ValueError: If ``skill_name`` is invalid.
            RuntimeError: If the row cannot be read back after the upsert.
        """
        normalized = _validate_skill_name(skill_name)
        now = _now_iso()
        async with aiosqlite.connect(str(db_path)) as db:
            db.row_factory = aiosqlite.Row
            # Upsert keyed on skill_name: repeated disables refresh the audit
            # columns instead of conflicting with the existing row.
            await db.execute(
                "INSERT INTO skill_states (skill_name, is_disabled, disabled_at, disabled_by) "
                "VALUES (?, 1, ?, ?) "
                "ON CONFLICT(skill_name) DO UPDATE SET "
                "  is_disabled = excluded.is_disabled, "
                "  disabled_at = excluded.disabled_at, "
                "  disabled_by = excluded.disabled_by",
                (normalized, now, disabled_by),
            )
            await db.commit()
            cursor = await db.execute(
                "SELECT * FROM skill_states WHERE skill_name = ?",
                (normalized,),
            )
            row = await cursor.fetchone()
            if row is None:
                # Unreachable right after the upsert. Raise explicitly rather
                # than `assert`, which is stripped under ``python -O``.
                raise RuntimeError(
                    f"skill_states row for {normalized!r} missing after upsert"
                )
            return skill_state_row_to_dict(row)

    async def enable_skill(
        self,
        db_path: Path,
        skill_name: str,
    ) -> bool:
        """Remove the disabled mark for a skill.

        Returns:
            ``True`` if a row was deleted (the skill was previously disabled),
            ``False`` if the skill was not disabled.

        Raises:
            ValueError: If ``skill_name`` is invalid.
        """
        normalized = _validate_skill_name(skill_name)
        async with aiosqlite.connect(str(db_path)) as db:
            cursor = await db.execute(
                "DELETE FROM skill_states WHERE skill_name = ?",
                (normalized,),
            )
            await db.commit()
            return cursor.rowcount > 0

    async def is_skill_disabled(
        self,
        db_path: Path,
        skill_name: str,
    ) -> bool:
        """Return ``True`` if the skill currently has a disabled row.

        Raises:
            ValueError: If ``skill_name`` is invalid.
        """
        normalized = _validate_skill_name(skill_name)
        async with aiosqlite.connect(str(db_path)) as db:
            cursor = await db.execute(
                "SELECT is_disabled FROM skill_states WHERE skill_name = ?",
                (normalized,),
            )
            row = await cursor.fetchone()
            if row is None:
                return False
            return bool(row[0])

    async def list_disabled_skills(
        self,
        db_path: Path,
    ) -> list[str]:
        """Return the names of all currently disabled skills, sorted ascending."""
        async with aiosqlite.connect(str(db_path)) as db:
            cursor = await db.execute(
                "SELECT skill_name FROM skill_states WHERE is_disabled = 1 ORDER BY skill_name ASC"
            )
            rows = await cursor.fetchall()
            return [row[0] for row in rows]

    # ------------------------------------------------------------------
    # YAML import / reload / update
    # ------------------------------------------------------------------

    async def import_skill(
        self,
        yaml_content: str,
        skills_dir: str,
        skill_registry: SkillRegistry | None = None,
        tool_registry: Any = None,
    ) -> dict[str, Any]:
        """Validate YAML, write it to a file, and load it into the registry.

        If a file for the same skill already exists, it is overwritten. If
        registration then fails, the previous file contents are restored.
        When there was no previous file, the new file is removed instead.

        Args:
            yaml_content: Raw YAML text for the skill config.
            skills_dir: Target directory for the YAML file. The file is
                written as ``{skills_dir}/{skill_name}.yaml``.
            skill_registry: Optional :class:`SkillRegistry` to register
                the loaded skill into. If ``None``, the skill is only
                written to disk.
            tool_registry: Optional :class:`ToolRegistry` for tool binding.

        Returns:
            Dict with ``name`` and ``path`` of the imported skill.

        Raises:
            ValueError: If the YAML is invalid, the skill name is invalid,
                the resolved path escapes ``skills_dir``, or registration
                fails.
        """
        data = _validate_yaml_content(yaml_content)
        skill_name = _validate_skill_name(data["name"])

        os.makedirs(skills_dir, exist_ok=True)
        file_path = self._skill_file_path(skills_dir, skill_name)

        # Snapshot any existing definition. A failed registration then puts
        # it back rather than deleting a previously working skill file.
        previous: bytes | None = None
        if os.path.isfile(file_path):
            with open(file_path, "rb") as f:
                previous = f.read()

        with open(file_path, "w", encoding="utf-8") as f:
            f.write(yaml_content)

        if skill_registry is not None:
            try:
                self._load_skill_file(file_path, skill_registry, tool_registry)
            except Exception as exc:
                # Roll the file back and re-raise as ValueError so the route
                # layer maps it to 400.
                self._restore_file(file_path, previous)
                raise ValueError(f"Skill YAML written but registration failed: {exc}") from exc

        return {"name": skill_name, "path": file_path}

    async def reload_skill(
        self,
        skill_name: str,
        skill_registry: SkillRegistry,
        skills_dir: str,
        tool_registry: Any = None,
    ) -> dict[str, Any]:
        """Unregister and reload a skill from its YAML file.

        Args:
            skill_name: Skill name to reload (validated against the regex).
            skill_registry: The live :class:`SkillRegistry` instance.
            skills_dir: Directory containing the YAML file.
            tool_registry: Optional :class:`ToolRegistry` for tool binding.

        Returns:
            Dict with ``name``, ``path``, and ``status`` of the reload.

        Raises:
            ValueError: If the skill name is invalid, the path escapes
                ``skills_dir``, or the YAML file cannot be found.
            Exception: Errors raised by the skill loader propagate
                unchanged. In that case the skill remains unregistered.
        """
        normalized = _validate_skill_name(skill_name)
        file_path = self._skill_file_path(skills_dir, normalized)
        if not os.path.isfile(file_path):
            raise ValueError(f"Skill {normalized!r} not found: no YAML file at {file_path}")

        # Unregister first so the registry picks up the new config cleanly.
        self._unregister_quietly(skill_registry, normalized, "reload")
        self._load_skill_file(file_path, skill_registry, tool_registry)

        return {
            "name": normalized,
            "path": file_path,
            "status": "reloaded",
        }

    async def update_skill_config(
        self,
        skill_name: str,
        config_patch: dict[str, Any],
        skills_dir: str,
        skill_registry: SkillRegistry,
        tool_registry: Any = None,
    ) -> dict[str, Any]:
        """Update a skill's YAML config in place and reload it.

        Performs a shallow merge: top-level keys in ``config_patch``
        overwrite the existing YAML values. Nested dicts are replaced
        wholesale (not deep-merged) to keep the semantics predictable. The
        ``name`` field cannot be changed through a patch.

        If the patched config fails to load, the original file bytes are
        restored and a best-effort attempt is made to re-register them.

        Args:
            skill_name: Skill name to update (validated against the regex).
            config_patch: Dict of top-level YAML keys to update.
            skills_dir: Directory containing the YAML file.
            skill_registry: The live :class:`SkillRegistry` instance.
            tool_registry: Optional :class:`ToolRegistry` for tool binding.

        Returns:
            Dict with ``name``, ``path``, and ``status`` of the update.

        Raises:
            ValueError: If the skill name is invalid, the YAML file cannot
                be found or parsed, or the patched config is invalid or
                fails to load.
        """
        normalized = _validate_skill_name(skill_name)
        file_path = self._skill_file_path(skills_dir, normalized)
        if not os.path.isfile(file_path):
            raise ValueError(f"Skill {normalized!r} not found: no YAML file at {file_path}")

        # Keep the exact original bytes so a failed update can be rolled back.
        with open(file_path, "rb") as f:
            original_bytes = f.read()
        try:
            # UnicodeDecodeError is already a ValueError subclass.
            existing = yaml.safe_load(original_bytes.decode("utf-8")) or {}
        except yaml.YAMLError as exc:
            raise ValueError(
                f"Existing skill YAML at {file_path} is not valid YAML: {exc}"
            ) from exc
        if not isinstance(existing, dict):
            raise ValueError(f"Existing skill YAML at {file_path} is not a mapping")

        # Shallow merge; the 'name' field is preserved (cannot rename via patch).
        patched = {**existing, **config_patch, "name": existing.get("name", normalized)}

        # Validate the patched config by round-tripping through YAML.
        try:
            patched_yaml = yaml.safe_dump(patched, default_flow_style=False, allow_unicode=True)
        except yaml.YAMLError as exc:
            raise ValueError(f"Failed to serialize patched config: {exc}") from exc
        _validate_yaml_content(patched_yaml)

        with open(file_path, "w", encoding="utf-8") as f:
            f.write(patched_yaml)

        # Reload the skill into the registry.
        self._unregister_quietly(skill_registry, normalized, "update")
        try:
            self._load_skill_file(file_path, skill_registry, tool_registry)
        except Exception as exc:
            # The patched config was rejected. Restore the original file and
            # try to re-register it, so the skill is not left unregistered
            # with a broken definition on disk.
            self._restore_file(file_path, original_bytes)
            self._unregister_quietly(skill_registry, normalized, "update")
            try:
                self._load_skill_file(file_path, skill_registry, tool_registry)
            except Exception:  # noqa: BLE001 — restoration is best-effort
                logger.warning(
                    "Failed to re-register skill %r after a rejected update",
                    normalized,
                    exc_info=True,
                )
            raise ValueError(f"Patched skill config failed to load: {exc}") from exc

        return {
            "name": normalized,
            "path": file_path,
            "status": "updated",
        }
# ---------------------------------------------------------------------------
# Module-level singleton (overridable in tests via set_skill_service)
# ---------------------------------------------------------------------------
_skill_service: SkillService | None = None
def get_skill_service() -> SkillService:
    """Return the shared :class:`SkillService`, creating it on first use.

    The instance is cached in the module-level ``_skill_service``.
    ``set_skill_service`` can replace or clear it.
    """
    global _skill_service
    service = _skill_service
    if service is None:
        service = SkillService()
        _skill_service = service
    return service
def set_skill_service(service: SkillService | None) -> None:
    """Replace the shared :class:`SkillService`; used by tests.

    Passing ``None`` clears the cache, so the next ``get_skill_service``
    call builds a fresh instance.
    """
    global _skill_service  # rebind the module-level cache, not a local name
    _skill_service = service