From fd7f6816b83a505a67fc091cce833d82e5da789c Mon Sep 17 00:00:00 2001 From: chiguyong Date: Sun, 21 Jun 2026 16:19:51 +0800 Subject: [PATCH] =?UTF-8?q?feat(admin):=20U6=20=E2=80=94=20Skill=20&=20KB?= =?UTF-8?q?=20management=20endpoints=20+=20department=20binding?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SkillService: enable/disable (persisted in skill_states table, schema v4), import from YAML (with path traversal + name validation), reload from file, update config. GET /skills now filters disabled skills. KbService: list/upload/delete documents with department_id binding. Added department_id field to KnowledgeSource + UploadedDocument. Department visibility: (bound to user depts) ∪ (global = None). 10 new admin endpoints: skill enable/disable/import/reload/update, KB documents CRUD, source sync/rebuild. All guarded by _require_admin. Implemented reload stub in skill_management.py (was no-op). 54 new tests (26 unit + 28 integration). Fixed 4 pre-existing lint errors. 357 admin tests pass, no regressions. --- src/agentkit/server/admin/kb_service.py | 246 +++++++++ src/agentkit/server/admin/skill_service.py | 428 ++++++++++++++ src/agentkit/server/auth/models.py | 49 +- src/agentkit/server/routes/admin.py | 284 ++++++++++ src/agentkit/server/routes/kb_management.py | 19 + .../server/routes/skill_management.py | 44 +- src/agentkit/server/routes/skills.py | 82 ++- .../integration/admin/test_skill_kb_routes.py | 522 ++++++++++++++++++ tests/unit/admin/test_models.py | 30 +- tests/unit/admin/test_skill_service.py | 386 +++++++++++++ tests/unit/auth/test_models.py | 6 +- tests/unit/server/test_skill_management.py | 29 +- 12 files changed, 2070 insertions(+), 55 deletions(-) create mode 100644 src/agentkit/server/admin/kb_service.py create mode 100644 src/agentkit/server/admin/skill_service.py create mode 100644 tests/integration/admin/test_skill_kb_routes.py create mode 100644 tests/unit/admin/test_skill_service.py diff --git a/src/agentkit/server/admin/kb_service.py b/src/agentkit/server/admin/kb_service.py new file mode 100644 index 0000000..1486364 --- /dev/null +++ b/src/agentkit/server/admin/kb_service.py @@ -0,0 +1,246 @@ +"""KbService — admin-driven KB document/source management (U6). + +This module wraps the existing in-memory :class:`KnowledgeSourceStore` +from :mod:`agentkit.server.routes.kb_management` so admin routes can +manage documents and sources through a service layer (mirroring +:class:`DepartmentService` and :class:`SkillService`). + +Department filtering +-------------------- +When listing documents, the service filters by department visibility: +- documents whose ``department_id`` is in the caller's department set, OR +- documents whose ``department_id`` is ``None`` (global documents). + +Admin callers (``department_ids=None``) bypass filtering and see all +documents. + +The service is a module-level singleton (see :func:`get_kb_service`) +so tests can inject a custom instance via :func:`set_kb_service`. +""" + +from __future__ import annotations + +import logging +import uuid +from datetime import datetime, timezone +from typing import Any + +from agentkit.server.routes.kb_management import ( + KnowledgeSource, + KnowledgeSourceStore, + UploadedDocument, + get_source_store, +) + +logger = logging.getLogger(__name__) + + +def _now_iso() -> str: + """Return current UTC time as ISO 8601 string.""" + return datetime.now(timezone.utc).isoformat() + + +class KbService: + """Wraps :class:`KnowledgeSourceStore` with admin-friendly operations. + + The service does NOT own the store — it delegates to the + module-level singleton from :mod:`kb_management`. This keeps a + single source of truth for KB state (the existing routes still + work) while giving admin routes a clean service boundary. + + Department filtering is applied at the service layer so admin + routes can pass ``department_ids`` directly from the + :class:`DepartmentContext` dependency. + """ + + def __init__(self, store: KnowledgeSourceStore | None = None) -> None: + self._store = store + + def _resolve_store(self) -> KnowledgeSourceStore: + """Return the injected store, or the module singleton.""" + if self._store is not None: + return self._store + return get_source_store() + + # ------------------------------------------------------------------ + # Document operations + # ------------------------------------------------------------------ + + def list_documents( + self, + department_ids: list[str] | None = None, + source_id: str | None = None, + ) -> list[dict[str, Any]]: + """List documents, optionally filtered by source and department. + + Args: + department_ids: Caller's department ids. ``None`` means admin + (no filtering — return all documents). An empty list means + a user with no departments (return only global documents). + source_id: Optional source id filter. + + Returns: + List of document dicts (JSON-safe). + """ + store = self._resolve_store() + documents = store.list_documents(source_id=source_id) + + # Admin bypass: department_ids=None means "see everything". + if department_ids is None: + return [self._doc_to_dict(d) for d in documents] + + # Non-admin: visible = (docs in user's departments) ∪ (global docs). + dept_set = set(department_ids) + visible = [d for d in documents if d.department_id is None or d.department_id in dept_set] + return [self._doc_to_dict(d) for d in visible] + + def upload_document( + self, + filename: str, + content: bytes, + source_id: str = "", + department_id: str | None = None, + ) -> dict[str, Any]: + """Upload a document with optional department binding. + + Args: + filename: Original filename. + content: File content bytes (used to estimate chunk count). + source_id: Optional source id. Defaults to ``"local"``. + department_id: Optional department id to bind the document to. + ``None`` means the document is global (visible to all). + + Returns: + The uploaded document dict. + """ + store = self._resolve_store() + effective_source_id = source_id or "local" + + # Ensure a local source exists for default uploads. + if effective_source_id == "local": + local_sources = [s for s in store.list_sources() if s.type == "local"] + if not local_sources: + store.add_source("本地文档", "local", {}) + + # Estimate chunks from content length (rough approximation, + # mirrors the existing /kb-management/documents/upload route). + chunks = max(1, len(content) // 500) + + doc = UploadedDocument( + document_id=str(uuid.uuid4()), + filename=filename, + source_id=effective_source_id, + chunks=chunks, + status="indexed", + department_id=department_id, + ) + store.add_document(doc) + return self._doc_to_dict(doc) + + def delete_document(self, document_id: str) -> bool: + """Delete a document by id. Returns ``True`` if deleted.""" + store = self._resolve_store() + return store.delete_document(document_id) + + # ------------------------------------------------------------------ + # Source operations + # ------------------------------------------------------------------ + + def list_sources(self) -> list[dict[str, Any]]: + """List all KB sources.""" + store = self._resolve_store() + return [self._source_to_dict(s) for s in store.list_sources()] + + def get_source(self, source_id: str) -> dict[str, Any] | None: + """Return a single source by id, or ``None`` if not found.""" + store = self._resolve_store() + source = store.get_source(source_id) + return self._source_to_dict(source) if source else None + + def sync_source(self, source_id: str) -> dict[str, Any]: + """Trigger a source sync (stub — marks status as ``syncing``). + + Raises: + ValueError: If the source does not exist. + """ + store = self._resolve_store() + source = store.get_source(source_id) + if source is None: + raise ValueError(f"KB source {source_id!r} not found") + # Stub: mark status as syncing. A real implementation would + # kick off an async sync job and return a job id. + source.status = "syncing" + source.last_synced = _now_iso() + return { + "source_id": source_id, + "status": "syncing", + "message": "Sync started", + } + + def rebuild_index(self, source_id: str) -> dict[str, Any]: + """Rebuild the index for a source (stub — marks status). + + Raises: + ValueError: If the source does not exist. + """ + store = self._resolve_store() + source = store.get_source(source_id) + if source is None: + raise ValueError(f"KB source {source_id!r} not found") + # Stub: mark status as rebuilding. + source.status = "rebuilding" + return { + "source_id": source_id, + "status": "rebuilding", + "message": "Index rebuild started", + } + + # ------------------------------------------------------------------ + # Serialization helpers + # ------------------------------------------------------------------ + + def _doc_to_dict(self, doc: UploadedDocument) -> dict[str, Any]: + """Convert an :class:`UploadedDocument` to a JSON-safe dict.""" + return { + "document_id": doc.document_id, + "filename": doc.filename, + "source_id": doc.source_id, + "chunks": doc.chunks, + "status": doc.status, + "created_at": doc.created_at, + "department_id": doc.department_id, + } + + def _source_to_dict(self, source: KnowledgeSource) -> dict[str, Any]: + """Convert a :class:`KnowledgeSource` to a JSON-safe dict.""" + return { + "id": source.id, + "name": source.name, + "type": source.type, + "status": source.status, + "document_count": source.document_count, + "last_synced": source.last_synced, + "department_id": source.department_id, + } + + +# --------------------------------------------------------------------------- +# Module-level singleton (overridable in tests via set_kb_service) +# --------------------------------------------------------------------------- + + +_kb_service: KbService | None = None + + +def get_kb_service() -> KbService: + """Return the process-wide :class:`KbService` (lazy singleton).""" + global _kb_service + if _kb_service is None: + _kb_service = KbService() + return _kb_service + + +def set_kb_service(service: KbService | None) -> None: + """Inject a custom :class:`KbService` (used by tests).""" + global _kb_service + _kb_service = service diff --git a/src/agentkit/server/admin/skill_service.py b/src/agentkit/server/admin/skill_service.py new file mode 100644 index 0000000..faee160 --- /dev/null +++ b/src/agentkit/server/admin/skill_service.py @@ -0,0 +1,428 @@ +"""SkillService — admin-driven skill enable/disable + import/reload (U6). + +This module is the single owner of the ``skill_states`` table (admin +disable markers) and the YAML import/reload pipeline. Web UI routes +(``/api/v1/admin/skills/*``) and the existing skill-management route +both call into :class:`SkillService` rather than touching the table +or filesystem directly, keeping validation rules (skill-name regex, +path-traversal guard, YAML validation) in one place. + +The service is a module-level singleton (see :func:`get_skill_service`) +so tests can inject a custom instance via :func:`set_skill_service`. + +Design notes +------------ +- Each method opens its own short-lived :class:`aiosqlite.Connection` + (mirrors :class:`DepartmentService`). +- ``import_skill`` reuses the validation regex from + :mod:`agentkit.server.routes.skills` but inlines a copy to avoid a + circular import (routes.skills → admin.skill_service → routes.skills). +- ``reload_skill`` and ``update_skill_config`` accept a + :class:`SkillRegistry` instance so the caller can pass the live + registry from ``app.state``. +""" + +from __future__ import annotations + +import logging +import os +import re +from datetime import datetime, timezone +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import aiosqlite +import yaml + +from agentkit.server.auth.models import skill_state_row_to_dict + +if TYPE_CHECKING: + from agentkit.skills.registry import SkillRegistry + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Constants & helpers +# --------------------------------------------------------------------------- + +# Strict skill name validation: lowercase alphanumeric, hyphens, underscores. +# Inlined from routes/skills.py to avoid a circular import. +_SKILL_NAME_RE = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$") + + +def _now_iso() -> str: + """Return current UTC time as ISO 8601 string.""" + return datetime.now(timezone.utc).isoformat() + + +def _validate_skill_name(name: str) -> str: + """Validate and normalize a skill name. Raises ``ValueError`` on invalid input. + + Mirrors the regex in :mod:`agentkit.server.routes.skills` but raises + ``ValueError`` (not ``HTTPException``) so the service is HTTP-agnostic. + """ + if not isinstance(name, str): + raise ValueError("Skill name must be a string") + normalized = name.strip().lower() + if not _SKILL_NAME_RE.match(normalized): + raise ValueError( + f"Invalid skill name {name!r}: must contain only lowercase letters, " + f"digits, hyphens, and underscores (1-64 chars)" + ) + return normalized + + +def _validate_yaml_content(content: str) -> dict[str, Any]: + """Validate YAML content. Returns parsed dict. Raises ``ValueError`` on invalid input. + + Mirrors the validation in :mod:`agentkit.server.routes.skills` but + raises ``ValueError`` (not ``HTTPException``). + """ + try: + data = yaml.safe_load(content) + except yaml.YAMLError as exc: + raise ValueError(f"Invalid YAML content: {exc}") from exc + + if not isinstance(data, dict): + raise ValueError("Skill YAML must be a mapping/dict") + + if "name" not in data: + raise ValueError("Skill YAML must contain a 'name' field") + + return data + + +def _is_within_directory(file_path: str, directory: str) -> bool: + """Return ``True`` if ``file_path`` resolves inside ``directory``. + + Uses :func:`os.path.realpath` to canonicalize both paths before + comparison, so symlink traversal and ``..`` segments are detected. + """ + real_file = os.path.realpath(file_path) + real_dir = os.path.realpath(directory) + return real_file.startswith(real_dir + os.sep) or real_file == real_dir + + +# --------------------------------------------------------------------------- +# Service +# --------------------------------------------------------------------------- + + +class SkillService: + """Admin-driven skill enable/disable + import/reload operations. + + All DB-touching methods are async and take ``db_path: Path`` as the + first argument. YAML-touching methods take ``skills_dir: str`` and + optionally a :class:`SkillRegistry` instance for live reload. + """ + + # ------------------------------------------------------------------ + # Enable / disable (DB state) + # ------------------------------------------------------------------ + + async def disable_skill( + self, + db_path: Path, + skill_name: str, + disabled_by: str | None = None, + ) -> dict[str, Any]: + """Mark a skill as disabled in the DB. + + Idempotent: if the skill is already disabled, the row is updated + (refreshing ``disabled_at`` and ``disabled_by``). + + Args: + db_path: Path to the auth SQLite DB. + skill_name: Skill name to disable (validated against the regex). + disabled_by: Optional admin user id (audit trail). + + Returns: + The disabled skill state dict. + """ + normalized = _validate_skill_name(skill_name) + now = _now_iso() + async with aiosqlite.connect(str(db_path)) as db: + db.row_factory = aiosqlite.Row + await db.execute( + "INSERT INTO skill_states (skill_name, is_disabled, disabled_at, disabled_by) " + "VALUES (?, 1, ?, ?) " + "ON CONFLICT(skill_name) DO UPDATE SET " + " is_disabled = excluded.is_disabled, " + " disabled_at = excluded.disabled_at, " + " disabled_by = excluded.disabled_by", + (normalized, now, disabled_by), + ) + await db.commit() + cursor = await db.execute( + "SELECT * FROM skill_states WHERE skill_name = ?", + (normalized,), + ) + row = await cursor.fetchone() + assert row is not None # we just upserted it + return skill_state_row_to_dict(row) + + async def enable_skill( + self, + db_path: Path, + skill_name: str, + ) -> bool: + """Remove the disabled mark for a skill. + + Returns: + ``True`` if a row was deleted (skill was previously disabled), + ``False`` if the skill was not disabled. + """ + normalized = _validate_skill_name(skill_name) + async with aiosqlite.connect(str(db_path)) as db: + cursor = await db.execute( + "DELETE FROM skill_states WHERE skill_name = ?", + (normalized,), + ) + await db.commit() + return cursor.rowcount > 0 + + async def is_skill_disabled( + self, + db_path: Path, + skill_name: str, + ) -> bool: + """Return ``True`` if the skill is currently disabled.""" + normalized = _validate_skill_name(skill_name) + async with aiosqlite.connect(str(db_path)) as db: + cursor = await db.execute( + "SELECT is_disabled FROM skill_states WHERE skill_name = ?", + (normalized,), + ) + row = await cursor.fetchone() + if row is None: + return False + return bool(row[0]) + + async def list_disabled_skills( + self, + db_path: Path, + ) -> list[str]: + """Return the list of skill names currently disabled.""" + async with aiosqlite.connect(str(db_path)) as db: + cursor = await db.execute( + "SELECT skill_name FROM skill_states WHERE is_disabled = 1 ORDER BY skill_name ASC" + ) + rows = await cursor.fetchall() + return [row[0] for row in rows] + + # ------------------------------------------------------------------ + # YAML import / reload / update + # ------------------------------------------------------------------ + + async def import_skill( + self, + yaml_content: str, + skills_dir: str, + skill_registry: SkillRegistry | None = None, + tool_registry: Any = None, + ) -> dict[str, Any]: + """Validate YAML, write to file, and load into the registry. + + Args: + yaml_content: Raw YAML text for the skill config. + skills_dir: Target directory for the YAML file. The file is + written as ``{skills_dir}/{skill_name}.yaml``. + skill_registry: Optional :class:`SkillRegistry` to register + the loaded skill into. If ``None``, the skill is only + written to disk. + tool_registry: Optional :class:`ToolRegistry` for tool binding. + + Returns: + Dict with ``name`` and ``path`` of the imported skill. + + Raises: + ValueError: If the YAML is invalid, the skill name is invalid, + or the resolved path escapes ``skills_dir``. + """ + data = _validate_yaml_content(yaml_content) + raw_name = data["name"] + skill_name = _validate_skill_name(raw_name) + + os.makedirs(skills_dir, exist_ok=True) + file_path = os.path.join(skills_dir, f"{skill_name}.yaml") + + # Path traversal protection: the resolved path must stay inside + # skills_dir. (skill_name is regex-validated so this should always + # pass, but we double-check defensively.) + if not _is_within_directory(file_path, skills_dir): + raise ValueError(f"Resolved skill path escapes skills_dir: {file_path}") + + with open(file_path, "w", encoding="utf-8") as f: + f.write(yaml_content) + + if skill_registry is not None: + try: + from agentkit.skills.loader import SkillLoader + + loader = SkillLoader( + skill_registry=skill_registry, + tool_registry=tool_registry, + ) + loader.load_from_file(file_path) + except Exception: + # Remove the invalid YAML file and re-raise as ValueError + # so the route layer maps it to 400. + try: + os.remove(file_path) + except OSError: + pass + raise ValueError("Skill YAML written but registration failed") + + return {"name": skill_name, "path": file_path} + + async def reload_skill( + self, + skill_name: str, + skill_registry: SkillRegistry, + skills_dir: str, + tool_registry: Any = None, + ) -> dict[str, Any]: + """Unregister and reload a skill from its YAML file. + + Args: + skill_name: Skill name to reload (validated against the regex). + skill_registry: The live :class:`SkillRegistry` instance. + skills_dir: Directory containing the YAML file. + tool_registry: Optional :class:`ToolRegistry` for tool binding. + + Returns: + Dict with ``name``, ``path``, and ``status`` of the reload. + + Raises: + ValueError: If the skill name is invalid or the YAML file + cannot be found. + """ + normalized = _validate_skill_name(skill_name) + file_path = os.path.join(skills_dir, f"{normalized}.yaml") + if not os.path.isfile(file_path): + raise ValueError(f"Skill {normalized!r} not found: no YAML file at {file_path}") + + if not _is_within_directory(file_path, skills_dir): + raise ValueError(f"Resolved skill path escapes skills_dir: {file_path}") + + # Unregister the existing skill (if any) before reloading so the + # registry picks up the new config cleanly. + try: + skill_registry.unregister(normalized) + except Exception: # noqa: BLE001 — unregister is best-effort + logger.debug("Skill %r was not registered before reload", normalized) + + from agentkit.skills.loader import SkillLoader + + loader = SkillLoader( + skill_registry=skill_registry, + tool_registry=tool_registry, + ) + loader.load_from_file(file_path) + + return { + "name": normalized, + "path": file_path, + "status": "reloaded", + } + + async def update_skill_config( + self, + skill_name: str, + config_patch: dict[str, Any], + skills_dir: str, + skill_registry: SkillRegistry, + tool_registry: Any = None, + ) -> dict[str, Any]: + """Update a skill's YAML config in-place and reload it. + + Performs a shallow merge: top-level keys in ``config_patch`` + overwrite the existing YAML values. Nested dicts are replaced + wholesale (not deep-merged) to keep the semantics predictable. + + Args: + skill_name: Skill name to update (validated against the regex). + config_patch: Dict of top-level YAML keys to update. + skills_dir: Directory containing the YAML file. + skill_registry: The live :class:`SkillRegistry` instance. + tool_registry: Optional :class:`ToolRegistry` for tool binding. + + Returns: + Dict with ``name``, ``path``, and ``status`` of the update. + + Raises: + ValueError: If the skill name is invalid, the YAML file + cannot be found, or the patched config is invalid. + """ + normalized = _validate_skill_name(skill_name) + file_path = os.path.join(skills_dir, f"{normalized}.yaml") + if not os.path.isfile(file_path): + raise ValueError(f"Skill {normalized!r} not found: no YAML file at {file_path}") + + if not _is_within_directory(file_path, skills_dir): + raise ValueError(f"Resolved skill path escapes skills_dir: {file_path}") + + # Read existing YAML, apply patch, validate, write back. + with open(file_path, encoding="utf-8") as f: + existing = yaml.safe_load(f) or {} + + if not isinstance(existing, dict): + raise ValueError(f"Existing skill YAML at {file_path} is not a mapping") + + # Shallow merge: top-level keys in config_patch overwrite. + # The 'name' field is preserved (cannot rename via patch). + patched = {**existing, **config_patch, "name": existing.get("name", normalized)} + + # Validate the patched config by round-tripping through YAML. + try: + patched_yaml = yaml.safe_dump(patched, default_flow_style=False, allow_unicode=True) + except yaml.YAMLError as exc: + raise ValueError(f"Failed to serialize patched config: {exc}") from exc + + _validate_yaml_content(patched_yaml) + + with open(file_path, "w", encoding="utf-8") as f: + f.write(patched_yaml) + + # Reload the skill into the registry. + try: + skill_registry.unregister(normalized) + except Exception: # noqa: BLE001 — unregister is best-effort + logger.debug("Skill %r was not registered before update", normalized) + + from agentkit.skills.loader import SkillLoader + + loader = SkillLoader( + skill_registry=skill_registry, + tool_registry=tool_registry, + ) + loader.load_from_file(file_path) + + return { + "name": normalized, + "path": file_path, + "status": "updated", + } + + +# --------------------------------------------------------------------------- +# Module-level singleton (overridable in tests via set_skill_service) +# --------------------------------------------------------------------------- + + +_skill_service: SkillService | None = None + + +def get_skill_service() -> SkillService: + """Return the process-wide :class:`SkillService` (lazy singleton).""" + global _skill_service + if _skill_service is None: + _skill_service = SkillService() + return _skill_service + + +def set_skill_service(service: SkillService | None) -> None: + """Inject a custom :class:`SkillService` (used by tests).""" + global _skill_service + _skill_service = service diff --git a/src/agentkit/server/auth/models.py b/src/agentkit/server/auth/models.py index a42c857..3ced901 100644 --- a/src/agentkit/server/auth/models.py +++ b/src/agentkit/server/auth/models.py @@ -351,6 +351,24 @@ class DepartmentQuotaModel(Base): updated_at: Mapped[str] = mapped_column(String(64), nullable=False, default=_now_iso) +class SkillStateModel(Base): + """Skill enable/disable state (V4 — Admin Console U6). + + Each row records whether a named skill has been disabled by an + admin via the ``/admin/skills/{name}/disable`` endpoint. Skills + with no row here are considered enabled (the default). ``skill_name`` + references the skill registry identifier (not a DB FK — skills are + defined in YAML configs). + """ + + __tablename__ = "skill_states" + + skill_name: Mapped[str] = mapped_column(String(128), primary_key=True) + is_disabled: Mapped[bool] = mapped_column(default=True, nullable=False) + disabled_at: Mapped[str] = mapped_column(String(64), nullable=False, default=_now_iso) + disabled_by: Mapped[str | None] = mapped_column(String(36), nullable=True) + + # --------------------------------------------------------------------------- # Schema DDL (kept in sync with the models above for aiosqlite bootstrap) # --------------------------------------------------------------------------- @@ -565,6 +583,17 @@ CREATE TABLE IF NOT EXISTS department_quotas ( ); CREATE INDEX IF NOT EXISTS idx_department_quotas_department_id ON department_quotas(department_id); + +-- V4: skill_states records admin-disabled skills (U6 — Admin Console). +-- A skill with no row here is considered enabled (the default). Only +-- disabled skills have a row, with is_disabled=1. disabled_by records +-- the admin user id who disabled the skill (audit trail). +CREATE TABLE IF NOT EXISTS skill_states ( + skill_name TEXT PRIMARY KEY, + is_disabled INTEGER NOT NULL DEFAULT 1, + disabled_at TEXT NOT NULL, + disabled_by TEXT +); """ @@ -580,7 +609,10 @@ CREATE INDEX IF NOT EXISTS idx_department_quotas_department_id # V3 (2026-06-21, Admin Console): added departments, user_departments, # department_skill_bindings, department_kb_bindings, department_quotas. # No backfill needed — all new tables are additive. -_SCHEMA_VERSION = 3 +# +# V4 (2026-06-21, Admin Console U6): added skill_states table for +# admin-driven skill enable/disable. No backfill needed — additive. +_SCHEMA_VERSION = 4 _META_SCHEMA_VERSION_KEY = "schema_version" @@ -803,3 +835,18 @@ def user_department_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> di "department_id": row["department_id"], "created_at": row["created_at"], } + + +def skill_state_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> dict[str, Any]: + """Convert a ``skill_states`` row into a JSON-safe dict. + + The ``is_disabled`` field is normalized to a Python ``bool`` (the DB + stores 0/1). ``disabled_by`` is ``None`` when the disabling admin + is not recorded (e.g. legacy rows or system-initiated disables). + """ + return { + "skill_name": row["skill_name"], + "is_disabled": bool(row["is_disabled"]), + "disabled_at": row["disabled_at"], + "disabled_by": row["disabled_by"], + } diff --git a/src/agentkit/server/routes/admin.py b/src/agentkit/server/routes/admin.py index de1d1ac..1a0d739 100644 --- a/src/agentkit/server/routes/admin.py +++ b/src/agentkit/server/routes/admin.py @@ -23,11 +23,13 @@ from fastapi import APIRouter, Depends, HTTPException, Request from pydantic import BaseModel, ConfigDict from agentkit.server.admin.department_service import get_department_service +from agentkit.server.admin.kb_service import get_kb_service from agentkit.server.admin.llm_config_service import ( LlmConfigService, get_llm_config_service, ) from agentkit.server.admin.quota_service import get_quota_service +from agentkit.server.admin.skill_service import get_skill_service from agentkit.server.admin.user_service import get_user_service from agentkit.server.auth.dependencies import require_authenticated from agentkit.server.auth.models import DEFAULT_AUTH_DB_PATH, init_auth_db @@ -854,3 +856,285 @@ async def delete_department_quota( if not deleted: raise HTTPException(status_code=404, detail="Quota not found") return {"deleted": True} + + +# --------------------------------------------------------------------------- +# Skill management endpoints (U6) — enable/disable + import/reload +# --------------------------------------------------------------------------- + + +def _get_skills_dir(request: Request) -> str: + """Resolve the skills directory from ``app.state.server_config``. + + Mirrors the logic in :func:`agentkit.server.routes.skills._get_skills_dir` + so admin endpoints install/reload skills into the same directory as + the existing ``/skills/install`` route. + """ + server_config = getattr(request.app.state, "server_config", None) + if server_config and getattr(server_config, "skill_paths", None): + first_path = Path(server_config.skill_paths[0]) + if first_path.is_dir(): + return str(first_path) + # Fallback: configs/skills/ relative to cwd (matches routes.skills). + import os + + return os.path.join(os.getcwd(), "configs", "skills") + + +def _get_skill_registry(request: Request) -> Any: + """Return the live :class:`SkillRegistry` from ``app.state``. + + Raises HTTPException(500) if the registry is missing — admin skill + endpoints cannot function without it. + """ + registry = getattr(request.app.state, "skill_registry", None) + if registry is None: + raise HTTPException( + status_code=500, + detail="Skill registry not initialized on app.state", + ) + return registry + + +class SkillImportRequest(BaseModel): + """Body for ``POST /admin/skills/import``.""" + + model_config = ConfigDict(extra="forbid") + + yaml_content: str + + +class SkillUpdateRequest(BaseModel): + """Body for ``PATCH /admin/skills/{name}``.""" + + model_config = ConfigDict(extra="forbid") + + config: dict[str, Any] + + +@admin_router.post("/skills/{name}/enable") +async def enable_skill( + name: str, + request: Request, + admin: dict[str, Any] = Depends(_require_admin), +) -> dict[str, Any]: + """Enable a previously-disabled skill. + + Returns 200 ``{enabled: true}`` on success. Idempotent — returns + 200 even if the skill was not disabled. + """ + db_path = await _ensure_db(request) + svc = get_skill_service() + try: + enabled = await svc.enable_skill(db_path, name) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + return {"enabled": enabled, "skill_name": name} + + +@admin_router.post("/skills/{name}/disable") +async def disable_skill( + name: str, + request: Request, + admin: dict[str, Any] = Depends(_require_admin), +) -> dict[str, Any]: + """Disable a skill (hides it from ``GET /skills``). + + Returns 200 with the disabled skill state dict. + """ + db_path = await _ensure_db(request) + svc = get_skill_service() + try: + return await svc.disable_skill(db_path, name, disabled_by=admin.get("user_id")) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + + +@admin_router.patch("/skills/{name}") +async def update_skill( + name: str, + payload: SkillUpdateRequest, + request: Request, + admin: dict[str, Any] = Depends(_require_admin), +) -> dict[str, Any]: + """Update a skill's YAML config in-place and reload it. + + Returns 200 with the updated skill info. Returns 404 if the skill + YAML file does not exist, 400 if the patched config is invalid. + """ + skills_dir = _get_skills_dir(request) + registry = _get_skill_registry(request) + svc = get_skill_service() + try: + return await svc.update_skill_config( + name, + payload.config, + skills_dir, + registry, + ) + except ValueError as exc: + msg = str(exc) + if "not found" in msg: + raise HTTPException(status_code=404, detail=msg) from exc + raise HTTPException(status_code=400, detail=msg) from exc + + +@admin_router.post("/skills/import") +async def import_skill( + payload: SkillImportRequest, + request: Request, + admin: dict[str, Any] = Depends(_require_admin), +) -> dict[str, Any]: + """Import a skill from YAML content. + + Writes the YAML to ``{skills_dir}/{name}.yaml`` and registers it + in the live :class:`SkillRegistry`. Returns 200 with the imported + skill info. Returns 400 if the YAML is invalid or the skill name + is invalid. + """ + skills_dir = _get_skills_dir(request) + registry = _get_skill_registry(request) + svc = get_skill_service() + try: + return await svc.import_skill( + payload.yaml_content, + skills_dir, + skill_registry=registry, + ) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + + +@admin_router.post("/skills/{name}/reload") +async def reload_skill( + name: str, + request: Request, + admin: dict[str, Any] = Depends(_require_admin), +) -> dict[str, Any]: + """Reload a skill from its YAML file. + + Returns 200 with the reloaded skill info. Returns 404 if the YAML + file does not exist, 400 if the skill name is invalid. + """ + skills_dir = _get_skills_dir(request) + registry = _get_skill_registry(request) + svc = get_skill_service() + try: + return await svc.reload_skill(name, registry, skills_dir) + except ValueError as exc: + msg = str(exc) + if "not found" in msg: + raise HTTPException(status_code=404, detail=msg) from exc + raise HTTPException(status_code=400, detail=msg) from exc + + +# --------------------------------------------------------------------------- +# KB management endpoints (U6) — documents + sources +# --------------------------------------------------------------------------- + + +class KbDocumentUploadRequest(BaseModel): + """Body for ``POST /admin/kb/documents``.""" + + model_config = ConfigDict(extra="forbid") + + filename: str + content: str + source_id: str = "" + department_id: str | None = None + + +@admin_router.get("/kb/documents") +async def list_kb_documents( + request: Request, + source_id: str | None = None, + department_id: str | None = None, + admin: dict[str, Any] = Depends(_require_admin), +) -> dict[str, Any]: + """List KB documents (admin sees all). + + Query params: ``source_id`` (optional filter), ``department_id`` + (optional filter — when set, only documents bound to that + department or global documents are returned). + """ + svc = get_kb_service() + # Admin sees everything. If department_id is provided as a query + # filter, narrow to that department + global docs. + if department_id is not None: + dept_ids = [department_id] + else: + dept_ids = None # admin: no filtering + documents = svc.list_documents(department_ids=dept_ids, source_id=source_id) + return {"documents": documents} + + +@admin_router.post("/kb/documents", status_code=201) +async def upload_kb_document( + payload: KbDocumentUploadRequest, + request: Request, + admin: dict[str, Any] = Depends(_require_admin), +) -> dict[str, Any]: + """Upload a KB document with optional department binding. + + Returns 201 with the uploaded document dict. + """ + svc = get_kb_service() + return svc.upload_document( + filename=payload.filename, + content=payload.content.encode("utf-8"), + source_id=payload.source_id, + department_id=payload.department_id, + ) + + +@admin_router.delete("/kb/documents/{document_id}") +async def delete_kb_document( + document_id: str, + request: Request, + admin: dict[str, Any] = Depends(_require_admin), +) -> dict[str, Any]: + """Delete a KB document by id. + + Returns 200 ``{deleted: true}`` on success, 404 if not found. + """ + svc = get_kb_service() + deleted = svc.delete_document(document_id) + if not deleted: + raise HTTPException(status_code=404, detail="Document not found") + return {"deleted": True} + + +@admin_router.post("/kb/sources/{source_id}/sync") +async def sync_kb_source( + source_id: str, + request: Request, + admin: dict[str, Any] = Depends(_require_admin), +) -> dict[str, Any]: + """Trigger a sync for a KB source. + + Returns 200 with the sync status. Returns 404 if the source does + not exist. + """ + svc = get_kb_service() + try: + return svc.sync_source(source_id) + except ValueError as exc: + raise HTTPException(status_code=404, detail=str(exc)) from exc + + +@admin_router.post("/kb/sources/{source_id}/rebuild") +async def rebuild_kb_source( + source_id: str, + request: Request, + admin: dict[str, Any] = Depends(_require_admin), +) -> dict[str, Any]: + """Rebuild the index for a KB source. + + Returns 200 with the rebuild status. Returns 404 if the source + does not exist. + """ + svc = get_kb_service() + try: + return svc.rebuild_index(source_id) + except ValueError as exc: + raise HTTPException(status_code=404, detail=str(exc)) from exc diff --git a/src/agentkit/server/routes/kb_management.py b/src/agentkit/server/routes/kb_management.py index 7919ed8..680a601 100644 --- a/src/agentkit/server/routes/kb_management.py +++ b/src/agentkit/server/routes/kb_management.py @@ -73,6 +73,7 @@ class KnowledgeSource: status: str = "active" document_count: int = 0 last_synced: str | None = None + department_id: str | None = None @dataclass @@ -83,6 +84,7 @@ class UploadedDocument: chunks: int status: str created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + department_id: str | None = None class KnowledgeSourceStore: @@ -155,6 +157,23 @@ class KnowledgeSourceStore: _source_store = KnowledgeSourceStore() +def get_source_store() -> KnowledgeSourceStore: + """Return the process-wide :class:`KnowledgeSourceStore` singleton. + + Exposed as a function (rather than importing ``_source_store`` + directly) so tests can swap it out via monkeypatch, and so the + :class:`KbService` wrapper can access the store without reaching + into a private module attribute. + """ + return _source_store + + +def set_source_store(store: KnowledgeSourceStore) -> None: + """Inject a custom :class:`KnowledgeSourceStore` (used by tests).""" + global _source_store + _source_store = store + + # --------------------------------------------------------------------------- # Request / Response models # --------------------------------------------------------------------------- diff --git a/src/agentkit/server/routes/skill_management.py b/src/agentkit/server/routes/skill_management.py index 9465178..882ad76 100644 --- a/src/agentkit/server/routes/skill_management.py +++ b/src/agentkit/server/routes/skill_management.py @@ -54,9 +54,7 @@ def _skill_to_info(skill: Any) -> dict[str, Any]: if hasattr(skill, "config") and hasattr(skill.config, "capabilities"): caps = skill.config.capabilities if isinstance(caps, list): - capabilities = [ - c.tag if hasattr(c, "tag") else str(c) for c in caps - ] + capabilities = [c.tag if hasattr(c, "tag") else str(c) for c in caps] elif isinstance(caps, dict): capabilities = list(caps.keys()) @@ -160,7 +158,7 @@ async def check_skill_health(skill_name: str, req: Request): """Check the health of a specific skill.""" skill_registry = req.app.state.skill_registry try: - skill = skill_registry.get(skill_name) + skill_registry.get(skill_name) except Exception: raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found") @@ -213,17 +211,43 @@ async def list_capabilities(req: Request): @router.post("/skill-management/skills/{skill_name}/reload") async def reload_skill(skill_name: str, req: Request): - """Reload a skill configuration.""" + """Reload a skill configuration from its YAML file.""" skill_registry = req.app.state.skill_registry + # Verify the skill is currently registered (404 if not). try: - skill = skill_registry.get(skill_name) + skill_registry.get(skill_name) except Exception: raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found") - # In a full implementation, this would reload the skill from its config source - # For now, just return success + # Resolve the skills directory (mirrors routes.skills._get_skills_dir). + import os + + skills_dir: str + server_config = getattr(req.app.state, "server_config", None) + if server_config and getattr(server_config, "skill_paths", None): + from pathlib import Path as _P + + first_path = _P(server_config.skill_paths[0]) + if first_path.is_dir(): + skills_dir = str(first_path) + else: + skills_dir = os.path.join(os.getcwd(), "configs", "skills") + else: + skills_dir = os.path.join(os.getcwd(), "configs", "skills") + + from agentkit.server.admin.skill_service import get_skill_service + + svc = get_skill_service() + try: + result = await svc.reload_skill(skill_name, skill_registry, skills_dir) + except ValueError as exc: + msg = str(exc) + if "not found" in msg: + raise HTTPException(status_code=404, detail=msg) from exc + raise HTTPException(status_code=400, detail=msg) from exc + return { - "skill_name": skill_name, - "status": "reloaded", + "skill_name": result["name"], + "status": result["status"], "message": f"技能 '{skill_name}' 已重新加载", } diff --git a/src/agentkit/server/routes/skills.py b/src/agentkit/server/routes/skills.py index ee03c8d..abb8a6c 100644 --- a/src/agentkit/server/routes/skills.py +++ b/src/agentkit/server/routes/skills.py @@ -49,6 +49,7 @@ def _get_skills_dir(req: Request) -> str: if server_config and server_config.skill_paths: # Use the first configured skill path as the install target from pathlib import Path as _P + first_path = _P(server_config.skill_paths[0]) if first_path.is_dir(): return str(first_path) @@ -59,12 +60,16 @@ def _get_skills_dir(req: Request) -> str: def _validate_source_url(source: str) -> None: """Validate that a source URL points to an allowed domain (SSRF mitigation).""" from urllib.parse import urlparse + parsed = urlparse(source) if parsed.scheme not in ("https", "http"): - raise HTTPException(status_code=400, detail=f"Invalid source URL scheme: only http/https allowed") + raise HTTPException( + status_code=400, detail="Invalid source URL scheme: only http/https allowed" + ) # Block private/internal IPs by checking hostname import ipaddress import socket + hostname = parsed.hostname if hostname: try: @@ -82,12 +87,15 @@ def _validate_source_url(source: str) -> None: # Check domain allowlist for source URLs if hostname and hostname not in _ALLOWED_DOWNLOAD_DOMAINS: # Allow but log a warning for non-allowlisted domains - logger.warning(f"Source URL domain '{hostname}' is not in the allowlist: {_ALLOWED_DOWNLOAD_DOMAINS}") + logger.warning( + f"Source URL domain '{hostname}' is not in the allowlist: {_ALLOWED_DOWNLOAD_DOMAINS}" + ) def _validate_yaml_content(content: str) -> dict: """Validate YAML content before writing to disk. Returns parsed dict.""" import yaml + try: data = yaml.safe_load(content) except yaml.YAMLError as e: @@ -156,16 +164,37 @@ async def list_skills( Admin users (``role == "admin"``) bypass filtering and see all registered skills. Unauthenticated callers (API-key clients) see only global skills. + + Disabled-skill filtering (U6): skills marked as disabled by an + admin via ``POST /admin/skills/{name}/disable`` are excluded from + the response for ALL callers (admins included). This keeps the + public skill list in sync with the admin enable/disable state. """ skill_registry = req.app.state.skill_registry skills = skill_registry.list_skills() + # U6: filter out admin-disabled skills (applies to all callers). + db_path = _resolve_db_path(req) + if db_path.exists(): + try: + from agentkit.server.admin.skill_service import get_skill_service + + svc = get_skill_service() + disabled_names = set(await svc.list_disabled_skills(db_path)) + except Exception: # noqa: BLE001 — never block listing on DB errors + logger.exception("Failed to load disabled skills list — skipping filter") + disabled_names = set() + else: + disabled_names = set() + + if disabled_names: + skills = [s for s in skills if s.name not in disabled_names] + # Admins bypass department filtering. if not dept_ctx.should_filter: return _serialize_skills(skills) # Non-admin: filter by department bindings. - db_path = _resolve_db_path(req) all_names = [s.name for s in skills] try: visible_names = await filter_skills_by_department( @@ -218,15 +247,13 @@ async def mention_suggest(q: str = "", req: Request = None): if query: skills = [ - s for s in skills + s + for s in skills if query in s.name.lower() or (s.config.description and query in s.config.description.lower()) ] - return [ - {"name": s.name, "description": s.config.description or ""} - for s in skills[:8] - ] + return [{"name": s.name, "description": s.config.description or ""} for s in skills[:8]] @router.post("/skills/install") @@ -246,7 +273,9 @@ async def install_skill(request: InstallSkillRequest, req: Request): if source and source.startswith("http"): _validate_source_url(source) try: - async with httpx.AsyncClient(timeout=30, follow_redirects=True, max_redirects=3) as client: + async with httpx.AsyncClient( + timeout=30, follow_redirects=True, max_redirects=3 + ) as client: resp = await client.get(source) resp.raise_for_status() yaml_content = resp.text @@ -260,7 +289,9 @@ async def install_skill(request: InstallSkillRequest, req: Request): # Verify the path is within the skills directory skills_dir_base = _get_skills_dir(req) if not os.path.realpath(local_path).startswith(os.path.realpath(skills_dir_base)): - raise HTTPException(status_code=400, detail="Local file path must be within the skills directory") + raise HTTPException( + status_code=400, detail="Local file path must be within the skills directory" + ) try: with open(local_path, encoding="utf-8") as f: yaml_content = f.read() @@ -290,12 +321,17 @@ async def install_skill(request: InstallSkillRequest, req: Request): # Fallback: try a simpler search search_query2 = f"{skill_name} skill" encoded_query2 = urllib.parse.quote(search_query2) - github_api2 = f"https://api.github.com/search/code?q={encoded_query2}+extension:yaml&per_page=5" + github_api2 = ( + f"https://api.github.com/search/code?q={encoded_query2}+extension:yaml&per_page=5" + ) try: async with httpx.AsyncClient(timeout=15) as client: gh_resp2 = await client.get( github_api2, - headers={"Accept": "application/vnd.github.v3+json", "User-Agent": "agentkit"}, + headers={ + "Accept": "application/vnd.github.v3+json", + "User-Agent": "agentkit", + }, ) items = gh_resp2.json().get("items", []) except Exception: @@ -310,13 +346,19 @@ async def install_skill(request: InstallSkillRequest, req: Request): if raw_url: # Validate the URL is from github.com before transforming if not raw_url.startswith("https://github.com/"): - raise HTTPException(status_code=400, detail="Search result URL is not from github.com") - raw_url = raw_url.replace("github.com", "raw.githubusercontent.com").replace("/blob/", "/") + raise HTTPException( + status_code=400, detail="Search result URL is not from github.com" + ) + raw_url = raw_url.replace("github.com", "raw.githubusercontent.com").replace( + "/blob/", "/" + ) else: raise HTTPException(status_code=404, detail="Could not construct download URL") try: - async with httpx.AsyncClient(timeout=30, follow_redirects=True, max_redirects=3) as client: + async with httpx.AsyncClient( + timeout=30, follow_redirects=True, max_redirects=3 + ) as client: resp = await client.get(raw_url) resp.raise_for_status() yaml_content = resp.text @@ -342,6 +384,7 @@ async def install_skill(request: InstallSkillRequest, req: Request): registration_ok = False try: from agentkit.skills.loader import SkillLoader + loader = SkillLoader( skill_registry=skill_registry, tool_registry=tool_registry, @@ -357,7 +400,7 @@ async def install_skill(request: InstallSkillRequest, req: Request): os.remove(file_path) except Exception: pass - raise HTTPException(status_code=500, detail=f"Skill downloaded but registration failed") + raise HTTPException(status_code=500, detail="Skill downloaded but registration failed") return { "status": "installed", @@ -387,7 +430,9 @@ async def uninstall_skill(name: str, req: Request): yaml_path = os.path.join(skills_dir, f"{validated_name}.yaml") # Verify resolved path stays within skills_dir - if os.path.exists(yaml_path) and os.path.realpath(yaml_path).startswith(os.path.realpath(skills_dir)): + if os.path.exists(yaml_path) and os.path.realpath(yaml_path).startswith( + os.path.realpath(skills_dir) + ): os.remove(yaml_path) return {"status": "uninstalled", "name": validated_name} @@ -419,8 +464,7 @@ async def create_pipeline(request: CreatePipelineRequest, req: Request): return { "name": pipeline.name, "steps": [ - {"skill_name": s["skill_name"], "step_index": i} - for i, s in enumerate(request.steps) + {"skill_name": s["skill_name"], "step_index": i} for i, s in enumerate(request.steps) ], } diff --git a/tests/integration/admin/test_skill_kb_routes.py b/tests/integration/admin/test_skill_kb_routes.py new file mode 100644 index 0000000..45f61bf --- /dev/null +++ b/tests/integration/admin/test_skill_kb_routes.py @@ -0,0 +1,522 @@ +"""Integration tests for the skill/KB admin routes (U6). + +Uses FastAPI TestClient with a test app that mounts the +``admin_router`` from ``routes.admin`` plus the public ``skills`` +router from ``routes.skills`` (to verify disabled-skill filtering). +The ``_require_admin`` dependency is overridden via +``app.dependency_overrides`` so the tests don't need real JWTs. + +The :class:`SkillRegistry` is a real instance with a test skill +loaded from a temp YAML file, so import/reload/update endpoints +exercise the real SkillLoader pipeline. +""" + +from __future__ import annotations + +import os +import uuid +from pathlib import Path +from typing import Any + +import pytest +from fastapi import FastAPI, HTTPException +from fastapi.testclient import TestClient + +from agentkit.server.admin.kb_service import set_kb_service +from agentkit.server.admin.skill_service import set_skill_service +from agentkit.server.auth.models import init_auth_db +from agentkit.server.routes import admin as admin_routes_module +from agentkit.server.routes import skills as skills_routes_module +from agentkit.skills.registry import SkillRegistry + + +# --------------------------------------------------------------------------- +# Test data +# --------------------------------------------------------------------------- + + +_VALID_SKILL_YAML = """\ +name: admin_test_skill +agent_type: simple_generation +version: "1.0.0" +description: "A test skill for admin route testing" +task_mode: llm_generate +execution_mode: direct +max_steps: 1 +prompt: + identity: "Test" + instructions: "Handle test" +tools: [] +""" + +_VALID_SKILL_YAML_2 = """\ +name: another_test_skill +agent_type: simple_generation +version: "1.0.0" +description: "Another test skill" +task_mode: llm_generate +execution_mode: direct +max_steps: 1 +prompt: + identity: "Test2" + instructions: "Handle test 2" +tools: [] +""" + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +async def tmp_auth_db(tmp_path: Path) -> Path: + db_path = tmp_path / "admin_skill_kb.db" + await init_auth_db(db_path) + return db_path + + +@pytest.fixture +def skills_dir(tmp_path: Path) -> str: + """A temp skills directory for YAML files.""" + d = tmp_path / "skills" + d.mkdir() + return str(d) + + +@pytest.fixture +def skill_registry() -> SkillRegistry: + return SkillRegistry() + + +@pytest.fixture(autouse=True) +def _reset_singletons(): + """Reset SkillService and KbService singletons before/after each test.""" + set_skill_service(None) + set_kb_service(None) + yield + set_skill_service(None) + set_kb_service(None) + + +def _make_admin_user() -> dict[str, Any]: + return {"user_id": "admin-1", "username": "admin", "role": "admin"} + + +def _raise_forbidden() -> dict[str, Any]: + """Dependency override that simulates a non-admin (403) response.""" + raise HTTPException(status_code=403, detail="Admin permission required") + + +@pytest.fixture +def admin_app( + tmp_auth_db: Path, + skills_dir: str, + skill_registry: SkillRegistry, +) -> FastAPI: + """A minimal FastAPI app with admin + skills routers mounted. + + The ``_require_admin`` dependency is overridden to return a fake + admin user. The :class:`SkillRegistry` is set on ``app.state`` so + the admin skill endpoints can access it. + """ + app = FastAPI() + app.state.auth_db_path = str(tmp_auth_db) + app.state.skill_registry = skill_registry + + # Build a minimal server_config with skill_paths pointing at the + # temp skills_dir, so the admin endpoints write YAML there. + class _FakeServerConfig: + skill_paths = [skills_dir] + + app.state.server_config = _FakeServerConfig() + + app.include_router(admin_routes_module.admin_router, prefix="/api/v1") + app.include_router(skills_routes_module.router, prefix="/api/v1") + + # Default: allow admin access. + app.dependency_overrides[admin_routes_module._require_admin] = lambda: _make_admin_user() + # Override the department context to admin (bypass filtering) so + # GET /skills returns all skills (the disabled filter still applies). + from agentkit.server.admin.context import DepartmentContext + + app.dependency_overrides[skills_routes_module.get_department_context] = lambda: ( + DepartmentContext(user_id="admin-1", department_ids=[], is_admin=True) + ) + return app + + +@pytest.fixture +def admin_client(admin_app: FastAPI) -> TestClient: + return TestClient(admin_app) + + +def _write_skill_yaml(skills_dir: str, name: str, content: str) -> str: + """Write a skill YAML file to the skills dir and return the path.""" + path = os.path.join(skills_dir, f"{name}.yaml") + with open(path, "w", encoding="utf-8") as f: + f.write(content) + return path + + +# --------------------------------------------------------------------------- +# Skill enable/disable +# --------------------------------------------------------------------------- + + +class TestSkillEnableDisable: + def test_disable_skill_returns_200(self, admin_client: TestClient): + resp = admin_client.post("/api/v1/admin/skills/admin_test_skill/disable") + assert resp.status_code == 200, resp.text + body = resp.json() + assert body["skill_name"] == "admin_test_skill" + assert body["is_disabled"] is True + + def test_disable_then_get_skills_excludes_it( + self, + admin_client: TestClient, + skills_dir: str, + skill_registry: SkillRegistry, + ): + # Write a skill YAML and register it. + _write_skill_yaml(skills_dir, "admin_test_skill", _VALID_SKILL_YAML) + from agentkit.skills.loader import SkillLoader + + SkillLoader(skill_registry=skill_registry).load_from_file( + os.path.join(skills_dir, "admin_test_skill.yaml") + ) + + # Verify the skill is initially listed. + resp = admin_client.get("/api/v1/skills") + assert resp.status_code == 200 + names = [s["name"] for s in resp.json()] + assert "admin_test_skill" in names + + # Disable the skill. + resp = admin_client.post("/api/v1/admin/skills/admin_test_skill/disable") + assert resp.status_code == 200 + + # GET /skills should now exclude it. + resp = admin_client.get("/api/v1/skills") + assert resp.status_code == 200 + names = [s["name"] for s in resp.json()] + assert "admin_test_skill" not in names + + def test_enable_skill_returns_200( + self, + admin_client: TestClient, + skills_dir: str, + skill_registry: SkillRegistry, + ): + # Write and register the skill, then disable it. + _write_skill_yaml(skills_dir, "admin_test_skill", _VALID_SKILL_YAML) + from agentkit.skills.loader import SkillLoader + + SkillLoader(skill_registry=skill_registry).load_from_file( + os.path.join(skills_dir, "admin_test_skill.yaml") + ) + admin_client.post("/api/v1/admin/skills/admin_test_skill/disable") + + # Enable it. + resp = admin_client.post("/api/v1/admin/skills/admin_test_skill/enable") + assert resp.status_code == 200 + assert resp.json()["enabled"] is True + + # GET /skills should now include it. + resp = admin_client.get("/api/v1/skills") + names = [s["name"] for s in resp.json()] + assert "admin_test_skill" in names + + def test_enable_skill_not_disabled_returns_200(self, admin_client: TestClient): + """Enabling a skill that wasn't disabled returns enabled=False.""" + resp = admin_client.post("/api/v1/admin/skills/never_disabled/enable") + assert resp.status_code == 200 + assert resp.json()["enabled"] is False + + def test_disable_skill_invalid_name_returns_400(self, admin_client: TestClient): + resp = admin_client.post("/api/v1/admin/skills/Has Spaces/disable") + assert resp.status_code == 400 + + def test_non_admin_cannot_disable_skill(self, admin_app: FastAPI): + admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden + client = TestClient(admin_app) + resp = client.post("/api/v1/admin/skills/some_skill/disable") + assert resp.status_code == 403 + + +# --------------------------------------------------------------------------- +# Skill import +# --------------------------------------------------------------------------- + + +class TestSkillImport: + def test_import_valid_yaml_returns_200( + self, + admin_client: TestClient, + skills_dir: str, + skill_registry: SkillRegistry, + ): + resp = admin_client.post( + "/api/v1/admin/skills/import", + json={"yaml_content": _VALID_SKILL_YAML}, + ) + assert resp.status_code == 200, resp.text + body = resp.json() + assert body["name"] == "admin_test_skill" + assert os.path.isfile(body["path"]) + # Skill should be registered. + assert skill_registry.has_skill("admin_test_skill") + + def test_import_invalid_yaml_returns_400(self, admin_client: TestClient): + resp = admin_client.post( + "/api/v1/admin/skills/import", + json={"yaml_content": "not: valid: yaml: ["}, + ) + assert resp.status_code == 400 + + def test_import_missing_name_returns_400(self, admin_client: TestClient): + resp = admin_client.post( + "/api/v1/admin/skills/import", + json={"yaml_content": "agent_type: test\n"}, + ) + assert resp.status_code == 400 + + def test_import_invalid_skill_name_returns_400(self, admin_client: TestClient): + bad_yaml = 'name: "Has Spaces"\nagent_type: test\n' + resp = admin_client.post( + "/api/v1/admin/skills/import", + json={"yaml_content": bad_yaml}, + ) + assert resp.status_code == 400 + + def test_non_admin_cannot_import(self, admin_app: FastAPI): + admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden + client = TestClient(admin_app) + resp = client.post( + "/api/v1/admin/skills/import", + json={"yaml_content": _VALID_SKILL_YAML}, + ) + assert resp.status_code == 403 + + +# --------------------------------------------------------------------------- +# Skill reload +# --------------------------------------------------------------------------- + + +class TestSkillReload: + def test_reload_existing_skill_returns_200( + self, + admin_client: TestClient, + skills_dir: str, + skill_registry: SkillRegistry, + ): + # Write and register the skill first. + _write_skill_yaml(skills_dir, "admin_test_skill", _VALID_SKILL_YAML) + from agentkit.skills.loader import SkillLoader + + SkillLoader(skill_registry=skill_registry).load_from_file( + os.path.join(skills_dir, "admin_test_skill.yaml") + ) + + resp = admin_client.post("/api/v1/admin/skills/admin_test_skill/reload") + assert resp.status_code == 200, resp.text + body = resp.json() + assert body["name"] == "admin_test_skill" + assert body["status"] == "reloaded" + + def test_reload_nonexistent_skill_returns_404(self, admin_client: TestClient): + resp = admin_client.post("/api/v1/admin/skills/nonexistent/reload") + assert resp.status_code == 404 + + +# --------------------------------------------------------------------------- +# Skill update (PATCH) +# --------------------------------------------------------------------------- + + +class TestSkillUpdate: + def test_update_skill_returns_200( + self, + admin_client: TestClient, + skills_dir: str, + skill_registry: SkillRegistry, + ): + # Write and register the skill first. + _write_skill_yaml(skills_dir, "admin_test_skill", _VALID_SKILL_YAML) + from agentkit.skills.loader import SkillLoader + + SkillLoader(skill_registry=skill_registry).load_from_file( + os.path.join(skills_dir, "admin_test_skill.yaml") + ) + + resp = admin_client.patch( + "/api/v1/admin/skills/admin_test_skill", + json={"config": {"description": "Updated via admin API"}}, + ) + assert resp.status_code == 200, resp.text + body = resp.json() + assert body["status"] == "updated" + + # Verify the YAML file was updated. + import yaml + + with open(body["path"], encoding="utf-8") as f: + data = yaml.safe_load(f) + assert data["description"] == "Updated via admin API" + + def test_update_nonexistent_skill_returns_404(self, admin_client: TestClient): + resp = admin_client.patch( + "/api/v1/admin/skills/nonexistent", + json={"config": {"description": "x"}}, + ) + assert resp.status_code == 404 + + +# --------------------------------------------------------------------------- +# KB document endpoints +# --------------------------------------------------------------------------- + + +class TestKbDocumentRoutes: + def test_list_documents_returns_200(self, admin_client: TestClient): + resp = admin_client.get("/api/v1/admin/kb/documents") + assert resp.status_code == 200, resp.text + body = resp.json() + assert "documents" in body + assert isinstance(body["documents"], list) + + def test_upload_document_returns_201_with_department_id(self, admin_client: TestClient): + dept_id = str(uuid.uuid4()) + resp = admin_client.post( + "/api/v1/admin/kb/documents", + json={ + "filename": "test.txt", + "content": "hello world", + "department_id": dept_id, + }, + ) + assert resp.status_code == 201, resp.text + body = resp.json() + assert body["filename"] == "test.txt" + assert body["department_id"] == dept_id + assert "document_id" in body + + def test_upload_document_without_department_id(self, admin_client: TestClient): + resp = admin_client.post( + "/api/v1/admin/kb/documents", + json={"filename": "global.txt", "content": "global content"}, + ) + assert resp.status_code == 201 + body = resp.json() + assert body["department_id"] is None + + def test_upload_then_delete_document(self, admin_client: TestClient): + # Upload + resp = admin_client.post( + "/api/v1/admin/kb/documents", + json={"filename": "delete_me.txt", "content": "x"}, + ) + assert resp.status_code == 201 + doc_id = resp.json()["document_id"] + + # Delete + resp = admin_client.delete(f"/api/v1/admin/kb/documents/{doc_id}") + assert resp.status_code == 200 + assert resp.json() == {"deleted": True} + + # Second delete should 404. + resp = admin_client.delete(f"/api/v1/admin/kb/documents/{doc_id}") + assert resp.status_code == 404 + + def test_delete_nonexistent_document_returns_404(self, admin_client: TestClient): + resp = admin_client.delete("/api/v1/admin/kb/documents/nonexistent-id") + assert resp.status_code == 404 + + def test_list_documents_filters_by_department_id(self, admin_client: TestClient): + dept_a = str(uuid.uuid4()) + dept_b = str(uuid.uuid4()) + + # Upload 3 docs: one for dept_a, one for dept_b, one global. + admin_client.post( + "/api/v1/admin/kb/documents", + json={"filename": "a.txt", "content": "a", "department_id": dept_a}, + ) + admin_client.post( + "/api/v1/admin/kb/documents", + json={"filename": "b.txt", "content": "b", "department_id": dept_b}, + ) + admin_client.post( + "/api/v1/admin/kb/documents", + json={"filename": "global.txt", "content": "g"}, + ) + + # Filter by dept_a: should see dept_a's doc + global doc. + resp = admin_client.get("/api/v1/admin/kb/documents", params={"department_id": dept_a}) + assert resp.status_code == 200 + docs = resp.json()["documents"] + dept_ids = {d["department_id"] for d in docs} + assert dept_a in dept_ids + assert None in dept_ids # global doc + assert dept_b not in dept_ids + + def test_non_admin_cannot_list_documents(self, admin_app: FastAPI): + admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden + client = TestClient(admin_app) + resp = client.get("/api/v1/admin/kb/documents") + assert resp.status_code == 403 + + def test_non_admin_cannot_upload_document(self, admin_app: FastAPI): + admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden + client = TestClient(admin_app) + resp = client.post( + "/api/v1/admin/kb/documents", + json={"filename": "x.txt", "content": "x"}, + ) + assert resp.status_code == 403 + + +# --------------------------------------------------------------------------- +# KB source sync/rebuild +# --------------------------------------------------------------------------- + + +class TestKbSourceRoutes: + def test_sync_nonexistent_source_returns_404(self, admin_client: TestClient): + resp = admin_client.post("/api/v1/admin/kb/sources/nonexistent/sync") + assert resp.status_code == 404 + + def test_rebuild_nonexistent_source_returns_404(self, admin_client: TestClient): + resp = admin_client.post("/api/v1/admin/kb/sources/nonexistent/rebuild") + assert resp.status_code == 404 + + def test_sync_existing_source_returns_200(self, admin_client: TestClient): + # Add a source via the KbService (bypassing the route layer). + from agentkit.server.admin.kb_service import get_kb_service + + svc = get_kb_service() + store = svc._resolve_store() + source = store.add_source("Test Source", "local", {}) + + resp = admin_client.post(f"/api/v1/admin/kb/sources/{source.id}/sync") + assert resp.status_code == 200, resp.text + body = resp.json() + assert body["status"] == "syncing" + + def test_rebuild_existing_source_returns_200(self, admin_client: TestClient): + from agentkit.server.admin.kb_service import get_kb_service + + svc = get_kb_service() + store = svc._resolve_store() + source = store.add_source("Test Source", "local", {}) + + resp = admin_client.post(f"/api/v1/admin/kb/sources/{source.id}/rebuild") + assert resp.status_code == 200, resp.text + body = resp.json() + assert body["status"] == "rebuilding" + + def test_non_admin_cannot_sync_source(self, admin_app: FastAPI): + admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden + client = TestClient(admin_app) + resp = client.post("/api/v1/admin/kb/sources/any/sync") + assert resp.status_code == 403 diff --git a/tests/unit/admin/test_models.py b/tests/unit/admin/test_models.py index 82e8bae..ff97ef0 100644 --- a/tests/unit/admin/test_models.py +++ b/tests/unit/admin/test_models.py @@ -4,7 +4,7 @@ Covers: - ``init_auth_db`` creates the new V3 tables (departments, user_departments, department_skill_bindings, department_kb_bindings, department_quotas) - ``init_auth_db`` is idempotent (calling twice does not error) -- ``_SCHEMA_VERSION`` is recorded as 3 in ``auth_meta`` +- ``_SCHEMA_VERSION`` is recorded as 4 in ``auth_meta`` (V4 adds skill_states) - ``departments`` insert + query round-trip - ``user_departments`` many-to-many relationship (one user → many departments, one department → many users) @@ -109,9 +109,7 @@ async def _list_index_names(db: aiosqlite.Connection, table: str) -> set[str]: async def _list_table_names(db: aiosqlite.Connection) -> set[str]: """Return the set of table names in the SQLite file.""" - cursor = await db.execute( - "SELECT name FROM sqlite_master WHERE type='table'" - ) + cursor = await db.execute("SELECT name FROM sqlite_master WHERE type='table'") rows = await cursor.fetchall() return {row[0] for row in rows} @@ -122,9 +120,9 @@ async def _list_table_names(db: aiosqlite.Connection) -> set[str]: class TestSchemaVersion: - def test_schema_version_is_v3(self): - """V3 adds the department-scoped admin tables.""" - assert _SCHEMA_VERSION == 3 + def test_schema_version_is_v4(self): + """V4 adds the skill_states table (U6 — Admin Console).""" + assert _SCHEMA_VERSION == 4 def test_sqlalchemy_model_table_names(self): assert DepartmentModel.__tablename__ == "departments" @@ -165,13 +163,13 @@ class TestInitAuthDbTables: tables = await _list_table_names(db) assert "department_quotas" in tables - async def test_records_schema_version_3_in_auth_meta(self, fresh_db: Path): + async def test_records_schema_version_4_in_auth_meta(self, fresh_db: Path): async with aiosqlite.connect(str(fresh_db)) as db: db.row_factory = aiosqlite.Row cursor = await db.execute("SELECT value FROM auth_meta WHERE key='schema_version'") row = await cursor.fetchone() assert row is not None - assert row["value"] == "3" + assert row["value"] == "4" assert row["value"] == str(_SCHEMA_VERSION) async def test_init_auth_db_is_idempotent(self, tmp_path: Path): @@ -279,9 +277,7 @@ class TestDepartmentsCrud: ) await db.commit() db.row_factory = aiosqlite.Row - cursor = await db.execute( - "SELECT description FROM departments WHERE id=?", (dept_id,) - ) + cursor = await db.execute("SELECT description FROM departments WHERE id=?", (dept_id,)) row = await cursor.fetchone() assert row is not None assert row["description"] is None @@ -310,8 +306,7 @@ class TestUserDepartmentsManyToMany: await db.commit() db.row_factory = aiosqlite.Row cursor = await db.execute( - "SELECT department_id FROM user_departments WHERE user_id=? " - "ORDER BY department_id", + "SELECT department_id FROM user_departments WHERE user_id=? ORDER BY department_id", (user_id,), ) rows = await cursor.fetchall() @@ -336,8 +331,7 @@ class TestUserDepartmentsManyToMany: await db.commit() db.row_factory = aiosqlite.Row cursor = await db.execute( - "SELECT user_id FROM user_departments WHERE department_id=? " - "ORDER BY user_id", + "SELECT user_id FROM user_departments WHERE department_id=? ORDER BY user_id", (dept_id,), ) rows = await cursor.fetchall() @@ -394,9 +388,7 @@ class TestDepartmentSkillBindingsUnique: ) await db.commit() - async def test_same_skill_name_in_different_departments_is_allowed( - self, fresh_db: Path - ): + async def test_same_skill_name_in_different_departments_is_allowed(self, fresh_db: Path): dept_a = str(uuid.uuid4()) dept_b = str(uuid.uuid4()) async with aiosqlite.connect(str(fresh_db)) as db: diff --git a/tests/unit/admin/test_skill_service.py b/tests/unit/admin/test_skill_service.py new file mode 100644 index 0000000..1a81f09 --- /dev/null +++ b/tests/unit/admin/test_skill_service.py @@ -0,0 +1,386 @@ +"""Unit tests for SkillService (U6 — skill enable/disable + import/reload). + +Covers: +- disable_skill → is_skill_disabled returns True +- enable_skill → is_skill_disabled returns False +- list_disabled_skills → returns correct list +- import_skill with valid YAML → file written, skill loaded +- import_skill with invalid YAML → ValueError +- import_skill with path traversal attempt → ValueError +- reload_skill → unregister + reload +- update_skill_config → YAML updated + reloaded +""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from agentkit.server.admin.skill_service import SkillService, _validate_skill_name +from agentkit.server.auth.models import init_auth_db +from agentkit.skills.registry import SkillRegistry + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +async def fresh_db(tmp_path: Path) -> Path: + """A brand-new auth DB on a fresh path (no data).""" + db_path = tmp_path / "auth.db" + await init_auth_db(db_path) + return db_path + + +@pytest.fixture +def service() -> SkillService: + return SkillService() + + +@pytest.fixture +def skills_dir(tmp_path: Path) -> str: + """A temp skills directory for YAML files.""" + d = tmp_path / "skills" + d.mkdir() + return str(d) + + +@pytest.fixture +def skill_registry() -> SkillRegistry: + return SkillRegistry() + + +_VALID_SKILL_YAML = """\ +name: test_skill +agent_type: simple_generation +version: "1.0.0" +description: "A test skill for unit testing" +task_mode: llm_generate +execution_mode: direct +max_steps: 1 +prompt: + identity: "Test" + instructions: "Handle test" +tools: [] +""" + + +# --------------------------------------------------------------------------- +# Enable / disable +# --------------------------------------------------------------------------- + + +class TestDisableEnable: + async def test_disable_skill_marks_disabled(self, service: SkillService, fresh_db: Path): + result = await service.disable_skill(fresh_db, "test_skill", disabled_by="admin-1") + assert result["skill_name"] == "test_skill" + assert result["is_disabled"] is True + assert result["disabled_by"] == "admin-1" + assert "disabled_at" in result + + assert await service.is_skill_disabled(fresh_db, "test_skill") is True + + async def test_disable_skill_normalizes_name(self, service: SkillService, fresh_db: Path): + """Skill names are normalized to lowercase before storage.""" + await service.disable_skill(fresh_db, "TestSkill") + assert await service.is_skill_disabled(fresh_db, "testskill") is True + assert await service.is_skill_disabled(fresh_db, "TestSkill") is True + + async def test_disable_skill_is_idempotent(self, service: SkillService, fresh_db: Path): + """Disabling an already-disabled skill updates the row, not duplicates.""" + await service.disable_skill(fresh_db, "skill_a", disabled_by="admin-1") + await service.disable_skill(fresh_db, "skill_a", disabled_by="admin-2") + disabled = await service.list_disabled_skills(fresh_db) + assert disabled == ["skill_a"] + + async def test_enable_skill_removes_disabled_mark(self, service: SkillService, fresh_db: Path): + await service.disable_skill(fresh_db, "skill_b") + assert await service.is_skill_disabled(fresh_db, "skill_b") is True + + enabled = await service.enable_skill(fresh_db, "skill_b") + assert enabled is True + assert await service.is_skill_disabled(fresh_db, "skill_b") is False + + async def test_enable_skill_returns_false_if_not_disabled( + self, service: SkillService, fresh_db: Path + ): + """Enabling a skill that wasn't disabled returns False (no-op).""" + enabled = await service.enable_skill(fresh_db, "never_disabled") + assert enabled is False + + async def test_is_skill_disabled_returns_false_for_unknown( + self, service: SkillService, fresh_db: Path + ): + assert await service.is_skill_disabled(fresh_db, "unknown_skill") is False + + async def test_list_disabled_skills_returns_sorted_list( + self, service: SkillService, fresh_db: Path + ): + await service.disable_skill(fresh_db, "charlie") + await service.disable_skill(fresh_db, "alpha") + await service.disable_skill(fresh_db, "bravo") + result = await service.list_disabled_skills(fresh_db) + assert result == ["alpha", "bravo", "charlie"] + + async def test_list_disabled_skills_empty_when_none_disabled( + self, service: SkillService, fresh_db: Path + ): + result = await service.list_disabled_skills(fresh_db) + assert result == [] + + async def test_disable_skill_invalid_name_raises(self, service: SkillService, fresh_db: Path): + with pytest.raises(ValueError, match="Invalid skill name"): + await service.disable_skill(fresh_db, "Invalid Name With Spaces") + + async def test_disable_skill_uppercase_name_normalizes( + self, service: SkillService, fresh_db: Path + ): + """Uppercase names are normalized to lowercase (regex requires lowercase).""" + # 'TEST_SKILL' normalizes to 'test_skill' which matches the regex. + await service.disable_skill(fresh_db, "TEST_SKILL") + assert await service.is_skill_disabled(fresh_db, "test_skill") is True + + +# --------------------------------------------------------------------------- +# import_skill +# --------------------------------------------------------------------------- + + +class TestImportSkill: + async def test_import_skill_writes_file_and_loads( + self, + service: SkillService, + skills_dir: str, + skill_registry: SkillRegistry, + ): + result = await service.import_skill( + _VALID_SKILL_YAML, + skills_dir, + skill_registry=skill_registry, + ) + assert result["name"] == "test_skill" + assert Path(result["path"]).is_file() + # Skill should be registered in the registry. + assert skill_registry.has_skill("test_skill") + + async def test_import_skill_without_registry_writes_file_only( + self, + service: SkillService, + skills_dir: str, + ): + result = await service.import_skill(_VALID_SKILL_YAML, skills_dir) + assert result["name"] == "test_skill" + assert Path(result["path"]).is_file() + + async def test_import_skill_invalid_yaml_raises( + self, + service: SkillService, + skills_dir: str, + ): + with pytest.raises(ValueError, match="Invalid YAML"): + await service.import_skill("not: valid: yaml: [", skills_dir) + + async def test_import_skill_non_mapping_yaml_raises( + self, + service: SkillService, + skills_dir: str, + ): + with pytest.raises(ValueError, match="mapping"): + await service.import_skill("- just\n- a\n- list\n", skills_dir) + + async def test_import_skill_missing_name_field_raises( + self, + service: SkillService, + skills_dir: str, + ): + yaml_without_name = "agent_type: simple_generation\n" + with pytest.raises(ValueError, match="name"): + await service.import_skill(yaml_without_name, skills_dir) + + async def test_import_skill_invalid_name_raises( + self, + service: SkillService, + skills_dir: str, + ): + """YAML with a name that fails the regex raises ValueError.""" + bad_yaml = 'name: "Bad Name With Spaces"\nagent_type: test\n' + with pytest.raises(ValueError, match="Invalid skill name"): + await service.import_skill(bad_yaml, skills_dir) + + async def test_import_skill_path_traversal_blocked( + self, + service: SkillService, + skills_dir: str, + ): + """A YAML with a name containing path separators is rejected by the regex.""" + # The regex `^[a-z0-9][a-z0-9_-]{0,63}$` rejects '/' and '..', + # so path traversal via the name field is impossible. We verify + # this by attempting to import a YAML with a traversal name. + traversal_yaml = 'name: "../etc/passwd"\nagent_type: test\n' + with pytest.raises(ValueError, match="Invalid skill name"): + await service.import_skill(traversal_yaml, skills_dir) + + +# --------------------------------------------------------------------------- +# reload_skill +# --------------------------------------------------------------------------- + + +class TestReloadSkill: + async def test_reload_skill_unregisters_and_reloads( + self, + service: SkillService, + skills_dir: str, + skill_registry: SkillRegistry, + ): + # First import the skill. + await service.import_skill( + _VALID_SKILL_YAML, + skills_dir, + skill_registry=skill_registry, + ) + assert skill_registry.has_skill("test_skill") + + # Now reload it. + result = await service.reload_skill("test_skill", skill_registry, skills_dir) + assert result["name"] == "test_skill" + assert result["status"] == "reloaded" + assert skill_registry.has_skill("test_skill") + + async def test_reload_skill_missing_yaml_raises( + self, + service: SkillService, + skills_dir: str, + skill_registry: SkillRegistry, + ): + with pytest.raises(ValueError, match="not found"): + await service.reload_skill("nonexistent", skill_registry, skills_dir) + + async def test_reload_skill_invalid_name_raises( + self, + service: SkillService, + skills_dir: str, + skill_registry: SkillRegistry, + ): + with pytest.raises(ValueError, match="Invalid skill name"): + await service.reload_skill("Bad Name", skill_registry, skills_dir) + + +# --------------------------------------------------------------------------- +# update_skill_config +# --------------------------------------------------------------------------- + + +class TestUpdateSkillConfig: + async def test_update_skill_config_updates_yaml_and_reloads( + self, + service: SkillService, + skills_dir: str, + skill_registry: SkillRegistry, + ): + # First import the skill. + await service.import_skill( + _VALID_SKILL_YAML, + skills_dir, + skill_registry=skill_registry, + ) + + # Patch the description. + result = await service.update_skill_config( + "test_skill", + {"description": "Updated description"}, + skills_dir, + skill_registry, + ) + assert result["name"] == "test_skill" + assert result["status"] == "updated" + + # Verify the YAML file was updated. + import yaml + + with open(result["path"], encoding="utf-8") as f: + data = yaml.safe_load(f) + assert data["description"] == "Updated description" + # Original fields should be preserved. + assert data["name"] == "test_skill" + assert data["agent_type"] == "simple_generation" + + # Skill should still be registered. + assert skill_registry.has_skill("test_skill") + + async def test_update_skill_config_missing_yaml_raises( + self, + service: SkillService, + skills_dir: str, + skill_registry: SkillRegistry, + ): + with pytest.raises(ValueError, match="not found"): + await service.update_skill_config( + "nonexistent", + {"description": "x"}, + skills_dir, + skill_registry, + ) + + async def test_update_skill_config_preserves_name( + self, + service: SkillService, + skills_dir: str, + skill_registry: SkillRegistry, + ): + """Patching the 'name' field is ignored — name is preserved.""" + await service.import_skill( + _VALID_SKILL_YAML, + skills_dir, + skill_registry=skill_registry, + ) + + result = await service.update_skill_config( + "test_skill", + {"name": "should_be_ignored", "description": "new"}, + skills_dir, + skill_registry, + ) + import yaml + + with open(result["path"], encoding="utf-8") as f: + data = yaml.safe_load(f) + # Name should be preserved (not renamed). + assert data["name"] == "test_skill" + assert data["description"] == "new" + + +# --------------------------------------------------------------------------- +# _validate_skill_name helper +# --------------------------------------------------------------------------- + + +class TestValidateSkillName: + def test_valid_name_returns_normalized(self): + assert _validate_skill_name("test_skill") == "test_skill" + assert _validate_skill_name("TestSkill") == "testskill" + assert _validate_skill_name(" spaced ") == "spaced" + assert _validate_skill_name("a-b_c-1") == "a-b_c-1" + + def test_invalid_name_raises(self): + with pytest.raises(ValueError): + _validate_skill_name("Has Spaces") + with pytest.raises(ValueError): + _validate_skill_name("") + with pytest.raises(ValueError): + _validate_skill_name("-leading-dash") + with pytest.raises(ValueError): + _validate_skill_name("../traversal") + with pytest.raises(ValueError): + _validate_skill_name("has.dot") + with pytest.raises(ValueError): + _validate_skill_name("has/slash") + + def test_non_string_raises(self): + with pytest.raises(ValueError): + _validate_skill_name(None) # type: ignore[arg-type] + with pytest.raises(ValueError): + _validate_skill_name(123) # type: ignore[arg-type] diff --git a/tests/unit/auth/test_models.py b/tests/unit/auth/test_models.py index d3d33b1..a2296e7 100644 --- a/tests/unit/auth/test_models.py +++ b/tests/unit/auth/test_models.py @@ -118,9 +118,9 @@ async def _list_index_names(db: aiosqlite.Connection, table: str) -> set[str]: class TestSchemaVersion: - def test_schema_version_is_v3(self): - """The current schema version is 3 (V3 adds department-scoped admin tables).""" - assert _SCHEMA_VERSION == 3 + def test_schema_version_is_v4(self): + """The current schema version is 4 (V4 adds skill_states table).""" + assert _SCHEMA_VERSION == 4 def test_sqlalchemy_model_table_name(self): assert AuthSessionModel.__tablename__ == "auth_sessions" diff --git a/tests/unit/server/test_skill_management.py b/tests/unit/server/test_skill_management.py index aeae30c..6ef9107 100644 --- a/tests/unit/server/test_skill_management.py +++ b/tests/unit/server/test_skill_management.py @@ -2,13 +2,12 @@ from __future__ import annotations -from unittest.mock import AsyncMock - import pytest from fastapi.testclient import TestClient from agentkit.llm.gateway import LLMGateway from agentkit.server.app import create_app +from agentkit.server.config import ServerConfig from agentkit.skills.base import Skill, SkillConfig from agentkit.skills.registry import SkillRegistry from agentkit.tools.registry import ToolRegistry @@ -202,7 +201,31 @@ class TestListCapabilities: class TestReloadSkill: - def test_reload_skill(self, client, skill_registry): + def test_reload_skill(self, client, skill_registry, tmp_path): + # Write a valid YAML file for "reload_skill" to a temp skills dir. + skills_dir = tmp_path / "skills" + skills_dir.mkdir() + yaml_content = ( + "name: reload_skill\n" + 'agent_type: "test_type"\n' + 'version: "1.0.0"\n' + 'description: "Skill for reload testing"\n' + "task_mode: llm_generate\n" + "execution_mode: direct\n" + "max_steps: 1\n" + "intent:\n" + ' keywords: ["reload"]\n' + ' description: "reload test"\n' + "prompt:\n" + ' identity: "reload tester"\n' + ' instructions: "handle reload"\n' + "tools: []\n" + ) + (skills_dir / "reload_skill.yaml").write_text(yaml_content, encoding="utf-8") + + # Point the app at the temp skills dir and register the skill so the + # 404 guard in the route passes. + client.app.state.server_config = ServerConfig(skill_paths=[str(skills_dir)]) _register_skill(skill_registry, "reload_skill") response = client.post("/api/v1/skill-management/skills/reload_skill/reload")