feat(admin): U6 — Skill & KB management endpoints + department binding

SkillService: enable/disable (persisted in skill_states table, schema
v4), import from YAML (with path traversal + name validation), reload
from file, update config. GET /skills now filters disabled skills.

KbService: list/upload/delete documents with department_id binding.
Added department_id field to KnowledgeSource + UploadedDocument.
Department visibility: (bound to user depts) ∪ (global = None).

10 new admin endpoints: skill enable/disable/import/reload/update,
KB documents CRUD, source sync/rebuild. All guarded by _require_admin.

Implemented reload stub in skill_management.py (was no-op).

54 new tests (26 unit + 28 integration). Fixed 4 pre-existing lint
errors. 357 admin tests pass, no regressions.
This commit is contained in:
chiguyong 2026-06-21 16:19:51 +08:00
parent 980919fc95
commit fd7f6816b8
12 changed files with 2070 additions and 55 deletions

View File

@ -0,0 +1,246 @@
"""KbService — admin-driven KB document/source management (U6).
This module wraps the existing in-memory :class:`KnowledgeSourceStore`
from :mod:`agentkit.server.routes.kb_management` so admin routes can
manage documents and sources through a service layer (mirroring
:class:`DepartmentService` and :class:`SkillService`).
Department filtering
--------------------
When listing documents, the service filters by department visibility:
- documents whose ``department_id`` is in the caller's department set, OR
- documents whose ``department_id`` is ``None`` (global documents).
Admin callers (``department_ids=None``) bypass filtering and see all
documents.
The service is a module-level singleton (see :func:`get_kb_service`)
so tests can inject a custom instance via :func:`set_kb_service`.
"""
from __future__ import annotations
import logging
import uuid
from datetime import datetime, timezone
from typing import Any
from agentkit.server.routes.kb_management import (
KnowledgeSource,
KnowledgeSourceStore,
UploadedDocument,
get_source_store,
)
logger = logging.getLogger(__name__)
def _now_iso() -> str:
"""Return current UTC time as ISO 8601 string."""
return datetime.now(timezone.utc).isoformat()
class KbService:
"""Wraps :class:`KnowledgeSourceStore` with admin-friendly operations.
The service does NOT own the store it delegates to the
module-level singleton from :mod:`kb_management`. This keeps a
single source of truth for KB state (the existing routes still
work) while giving admin routes a clean service boundary.
Department filtering is applied at the service layer so admin
routes can pass ``department_ids`` directly from the
:class:`DepartmentContext` dependency.
"""
def __init__(self, store: KnowledgeSourceStore | None = None) -> None:
self._store = store
def _resolve_store(self) -> KnowledgeSourceStore:
"""Return the injected store, or the module singleton."""
if self._store is not None:
return self._store
return get_source_store()
# ------------------------------------------------------------------
# Document operations
# ------------------------------------------------------------------
def list_documents(
self,
department_ids: list[str] | None = None,
source_id: str | None = None,
) -> list[dict[str, Any]]:
"""List documents, optionally filtered by source and department.
Args:
department_ids: Caller's department ids. ``None`` means admin
(no filtering return all documents). An empty list means
a user with no departments (return only global documents).
source_id: Optional source id filter.
Returns:
List of document dicts (JSON-safe).
"""
store = self._resolve_store()
documents = store.list_documents(source_id=source_id)
# Admin bypass: department_ids=None means "see everything".
if department_ids is None:
return [self._doc_to_dict(d) for d in documents]
# Non-admin: visible = (docs in user's departments) (global docs).
dept_set = set(department_ids)
visible = [d for d in documents if d.department_id is None or d.department_id in dept_set]
return [self._doc_to_dict(d) for d in visible]
def upload_document(
self,
filename: str,
content: bytes,
source_id: str = "",
department_id: str | None = None,
) -> dict[str, Any]:
"""Upload a document with optional department binding.
Args:
filename: Original filename.
content: File content bytes (used to estimate chunk count).
source_id: Optional source id. Defaults to ``"local"``.
department_id: Optional department id to bind the document to.
``None`` means the document is global (visible to all).
Returns:
The uploaded document dict.
"""
store = self._resolve_store()
effective_source_id = source_id or "local"
# Ensure a local source exists for default uploads.
if effective_source_id == "local":
local_sources = [s for s in store.list_sources() if s.type == "local"]
if not local_sources:
store.add_source("本地文档", "local", {})
# Estimate chunks from content length (rough approximation,
# mirrors the existing /kb-management/documents/upload route).
chunks = max(1, len(content) // 500)
doc = UploadedDocument(
document_id=str(uuid.uuid4()),
filename=filename,
source_id=effective_source_id,
chunks=chunks,
status="indexed",
department_id=department_id,
)
store.add_document(doc)
return self._doc_to_dict(doc)
def delete_document(self, document_id: str) -> bool:
"""Delete a document by id. Returns ``True`` if deleted."""
store = self._resolve_store()
return store.delete_document(document_id)
# ------------------------------------------------------------------
# Source operations
# ------------------------------------------------------------------
def list_sources(self) -> list[dict[str, Any]]:
"""List all KB sources."""
store = self._resolve_store()
return [self._source_to_dict(s) for s in store.list_sources()]
def get_source(self, source_id: str) -> dict[str, Any] | None:
"""Return a single source by id, or ``None`` if not found."""
store = self._resolve_store()
source = store.get_source(source_id)
return self._source_to_dict(source) if source else None
def sync_source(self, source_id: str) -> dict[str, Any]:
"""Trigger a source sync (stub — marks status as ``syncing``).
Raises:
ValueError: If the source does not exist.
"""
store = self._resolve_store()
source = store.get_source(source_id)
if source is None:
raise ValueError(f"KB source {source_id!r} not found")
# Stub: mark status as syncing. A real implementation would
# kick off an async sync job and return a job id.
source.status = "syncing"
source.last_synced = _now_iso()
return {
"source_id": source_id,
"status": "syncing",
"message": "Sync started",
}
def rebuild_index(self, source_id: str) -> dict[str, Any]:
"""Rebuild the index for a source (stub — marks status).
Raises:
ValueError: If the source does not exist.
"""
store = self._resolve_store()
source = store.get_source(source_id)
if source is None:
raise ValueError(f"KB source {source_id!r} not found")
# Stub: mark status as rebuilding.
source.status = "rebuilding"
return {
"source_id": source_id,
"status": "rebuilding",
"message": "Index rebuild started",
}
# ------------------------------------------------------------------
# Serialization helpers
# ------------------------------------------------------------------
def _doc_to_dict(self, doc: UploadedDocument) -> dict[str, Any]:
"""Convert an :class:`UploadedDocument` to a JSON-safe dict."""
return {
"document_id": doc.document_id,
"filename": doc.filename,
"source_id": doc.source_id,
"chunks": doc.chunks,
"status": doc.status,
"created_at": doc.created_at,
"department_id": doc.department_id,
}
def _source_to_dict(self, source: KnowledgeSource) -> dict[str, Any]:
"""Convert a :class:`KnowledgeSource` to a JSON-safe dict."""
return {
"id": source.id,
"name": source.name,
"type": source.type,
"status": source.status,
"document_count": source.document_count,
"last_synced": source.last_synced,
"department_id": source.department_id,
}
# ---------------------------------------------------------------------------
# Module-level singleton (overridable in tests via set_kb_service)
# ---------------------------------------------------------------------------
_kb_service: KbService | None = None
def get_kb_service() -> KbService:
"""Return the process-wide :class:`KbService` (lazy singleton)."""
global _kb_service
if _kb_service is None:
_kb_service = KbService()
return _kb_service
def set_kb_service(service: KbService | None) -> None:
"""Inject a custom :class:`KbService` (used by tests)."""
global _kb_service
_kb_service = service

View File

@ -0,0 +1,428 @@
"""SkillService — admin-driven skill enable/disable + import/reload (U6).
This module is the single owner of the ``skill_states`` table (admin
disable markers) and the YAML import/reload pipeline. Web UI routes
(``/api/v1/admin/skills/*``) and the existing skill-management route
both call into :class:`SkillService` rather than touching the table
or filesystem directly, keeping validation rules (skill-name regex,
path-traversal guard, YAML validation) in one place.
The service is a module-level singleton (see :func:`get_skill_service`)
so tests can inject a custom instance via :func:`set_skill_service`.
Design notes
------------
- Each method opens its own short-lived :class:`aiosqlite.Connection`
(mirrors :class:`DepartmentService`).
- ``import_skill`` reuses the validation regex from
:mod:`agentkit.server.routes.skills` but inlines a copy to avoid a
circular import (routes.skills admin.skill_service routes.skills).
- ``reload_skill`` and ``update_skill_config`` accept a
:class:`SkillRegistry` instance so the caller can pass the live
registry from ``app.state``.
"""
from __future__ import annotations
import logging
import os
import re
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any
import aiosqlite
import yaml
from agentkit.server.auth.models import skill_state_row_to_dict
if TYPE_CHECKING:
from agentkit.skills.registry import SkillRegistry
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Constants & helpers
# ---------------------------------------------------------------------------
# Strict skill name validation: lowercase alphanumeric, hyphens, underscores.
# Inlined from routes/skills.py to avoid a circular import.
_SKILL_NAME_RE = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")
def _now_iso() -> str:
"""Return current UTC time as ISO 8601 string."""
return datetime.now(timezone.utc).isoformat()
def _validate_skill_name(name: str) -> str:
"""Validate and normalize a skill name. Raises ``ValueError`` on invalid input.
Mirrors the regex in :mod:`agentkit.server.routes.skills` but raises
``ValueError`` (not ``HTTPException``) so the service is HTTP-agnostic.
"""
if not isinstance(name, str):
raise ValueError("Skill name must be a string")
normalized = name.strip().lower()
if not _SKILL_NAME_RE.match(normalized):
raise ValueError(
f"Invalid skill name {name!r}: must contain only lowercase letters, "
f"digits, hyphens, and underscores (1-64 chars)"
)
return normalized
def _validate_yaml_content(content: str) -> dict[str, Any]:
"""Validate YAML content. Returns parsed dict. Raises ``ValueError`` on invalid input.
Mirrors the validation in :mod:`agentkit.server.routes.skills` but
raises ``ValueError`` (not ``HTTPException``).
"""
try:
data = yaml.safe_load(content)
except yaml.YAMLError as exc:
raise ValueError(f"Invalid YAML content: {exc}") from exc
if not isinstance(data, dict):
raise ValueError("Skill YAML must be a mapping/dict")
if "name" not in data:
raise ValueError("Skill YAML must contain a 'name' field")
return data
def _is_within_directory(file_path: str, directory: str) -> bool:
"""Return ``True`` if ``file_path`` resolves inside ``directory``.
Uses :func:`os.path.realpath` to canonicalize both paths before
comparison, so symlink traversal and ``..`` segments are detected.
"""
real_file = os.path.realpath(file_path)
real_dir = os.path.realpath(directory)
return real_file.startswith(real_dir + os.sep) or real_file == real_dir
# ---------------------------------------------------------------------------
# Service
# ---------------------------------------------------------------------------
class SkillService:
"""Admin-driven skill enable/disable + import/reload operations.
All DB-touching methods are async and take ``db_path: Path`` as the
first argument. YAML-touching methods take ``skills_dir: str`` and
optionally a :class:`SkillRegistry` instance for live reload.
"""
# ------------------------------------------------------------------
# Enable / disable (DB state)
# ------------------------------------------------------------------
async def disable_skill(
self,
db_path: Path,
skill_name: str,
disabled_by: str | None = None,
) -> dict[str, Any]:
"""Mark a skill as disabled in the DB.
Idempotent: if the skill is already disabled, the row is updated
(refreshing ``disabled_at`` and ``disabled_by``).
Args:
db_path: Path to the auth SQLite DB.
skill_name: Skill name to disable (validated against the regex).
disabled_by: Optional admin user id (audit trail).
Returns:
The disabled skill state dict.
"""
normalized = _validate_skill_name(skill_name)
now = _now_iso()
async with aiosqlite.connect(str(db_path)) as db:
db.row_factory = aiosqlite.Row
await db.execute(
"INSERT INTO skill_states (skill_name, is_disabled, disabled_at, disabled_by) "
"VALUES (?, 1, ?, ?) "
"ON CONFLICT(skill_name) DO UPDATE SET "
" is_disabled = excluded.is_disabled, "
" disabled_at = excluded.disabled_at, "
" disabled_by = excluded.disabled_by",
(normalized, now, disabled_by),
)
await db.commit()
cursor = await db.execute(
"SELECT * FROM skill_states WHERE skill_name = ?",
(normalized,),
)
row = await cursor.fetchone()
assert row is not None # we just upserted it
return skill_state_row_to_dict(row)
async def enable_skill(
self,
db_path: Path,
skill_name: str,
) -> bool:
"""Remove the disabled mark for a skill.
Returns:
``True`` if a row was deleted (skill was previously disabled),
``False`` if the skill was not disabled.
"""
normalized = _validate_skill_name(skill_name)
async with aiosqlite.connect(str(db_path)) as db:
cursor = await db.execute(
"DELETE FROM skill_states WHERE skill_name = ?",
(normalized,),
)
await db.commit()
return cursor.rowcount > 0
async def is_skill_disabled(
self,
db_path: Path,
skill_name: str,
) -> bool:
"""Return ``True`` if the skill is currently disabled."""
normalized = _validate_skill_name(skill_name)
async with aiosqlite.connect(str(db_path)) as db:
cursor = await db.execute(
"SELECT is_disabled FROM skill_states WHERE skill_name = ?",
(normalized,),
)
row = await cursor.fetchone()
if row is None:
return False
return bool(row[0])
async def list_disabled_skills(
self,
db_path: Path,
) -> list[str]:
"""Return the list of skill names currently disabled."""
async with aiosqlite.connect(str(db_path)) as db:
cursor = await db.execute(
"SELECT skill_name FROM skill_states WHERE is_disabled = 1 ORDER BY skill_name ASC"
)
rows = await cursor.fetchall()
return [row[0] for row in rows]
# ------------------------------------------------------------------
# YAML import / reload / update
# ------------------------------------------------------------------
async def import_skill(
self,
yaml_content: str,
skills_dir: str,
skill_registry: SkillRegistry | None = None,
tool_registry: Any = None,
) -> dict[str, Any]:
"""Validate YAML, write to file, and load into the registry.
Args:
yaml_content: Raw YAML text for the skill config.
skills_dir: Target directory for the YAML file. The file is
written as ``{skills_dir}/{skill_name}.yaml``.
skill_registry: Optional :class:`SkillRegistry` to register
the loaded skill into. If ``None``, the skill is only
written to disk.
tool_registry: Optional :class:`ToolRegistry` for tool binding.
Returns:
Dict with ``name`` and ``path`` of the imported skill.
Raises:
ValueError: If the YAML is invalid, the skill name is invalid,
or the resolved path escapes ``skills_dir``.
"""
data = _validate_yaml_content(yaml_content)
raw_name = data["name"]
skill_name = _validate_skill_name(raw_name)
os.makedirs(skills_dir, exist_ok=True)
file_path = os.path.join(skills_dir, f"{skill_name}.yaml")
# Path traversal protection: the resolved path must stay inside
# skills_dir. (skill_name is regex-validated so this should always
# pass, but we double-check defensively.)
if not _is_within_directory(file_path, skills_dir):
raise ValueError(f"Resolved skill path escapes skills_dir: {file_path}")
with open(file_path, "w", encoding="utf-8") as f:
f.write(yaml_content)
if skill_registry is not None:
try:
from agentkit.skills.loader import SkillLoader
loader = SkillLoader(
skill_registry=skill_registry,
tool_registry=tool_registry,
)
loader.load_from_file(file_path)
except Exception:
# Remove the invalid YAML file and re-raise as ValueError
# so the route layer maps it to 400.
try:
os.remove(file_path)
except OSError:
pass
raise ValueError("Skill YAML written but registration failed")
return {"name": skill_name, "path": file_path}
async def reload_skill(
self,
skill_name: str,
skill_registry: SkillRegistry,
skills_dir: str,
tool_registry: Any = None,
) -> dict[str, Any]:
"""Unregister and reload a skill from its YAML file.
Args:
skill_name: Skill name to reload (validated against the regex).
skill_registry: The live :class:`SkillRegistry` instance.
skills_dir: Directory containing the YAML file.
tool_registry: Optional :class:`ToolRegistry` for tool binding.
Returns:
Dict with ``name``, ``path``, and ``status`` of the reload.
Raises:
ValueError: If the skill name is invalid or the YAML file
cannot be found.
"""
normalized = _validate_skill_name(skill_name)
file_path = os.path.join(skills_dir, f"{normalized}.yaml")
if not os.path.isfile(file_path):
raise ValueError(f"Skill {normalized!r} not found: no YAML file at {file_path}")
if not _is_within_directory(file_path, skills_dir):
raise ValueError(f"Resolved skill path escapes skills_dir: {file_path}")
# Unregister the existing skill (if any) before reloading so the
# registry picks up the new config cleanly.
try:
skill_registry.unregister(normalized)
except Exception: # noqa: BLE001 — unregister is best-effort
logger.debug("Skill %r was not registered before reload", normalized)
from agentkit.skills.loader import SkillLoader
loader = SkillLoader(
skill_registry=skill_registry,
tool_registry=tool_registry,
)
loader.load_from_file(file_path)
return {
"name": normalized,
"path": file_path,
"status": "reloaded",
}
async def update_skill_config(
self,
skill_name: str,
config_patch: dict[str, Any],
skills_dir: str,
skill_registry: SkillRegistry,
tool_registry: Any = None,
) -> dict[str, Any]:
"""Update a skill's YAML config in-place and reload it.
Performs a shallow merge: top-level keys in ``config_patch``
overwrite the existing YAML values. Nested dicts are replaced
wholesale (not deep-merged) to keep the semantics predictable.
Args:
skill_name: Skill name to update (validated against the regex).
config_patch: Dict of top-level YAML keys to update.
skills_dir: Directory containing the YAML file.
skill_registry: The live :class:`SkillRegistry` instance.
tool_registry: Optional :class:`ToolRegistry` for tool binding.
Returns:
Dict with ``name``, ``path``, and ``status`` of the update.
Raises:
ValueError: If the skill name is invalid, the YAML file
cannot be found, or the patched config is invalid.
"""
normalized = _validate_skill_name(skill_name)
file_path = os.path.join(skills_dir, f"{normalized}.yaml")
if not os.path.isfile(file_path):
raise ValueError(f"Skill {normalized!r} not found: no YAML file at {file_path}")
if not _is_within_directory(file_path, skills_dir):
raise ValueError(f"Resolved skill path escapes skills_dir: {file_path}")
# Read existing YAML, apply patch, validate, write back.
with open(file_path, encoding="utf-8") as f:
existing = yaml.safe_load(f) or {}
if not isinstance(existing, dict):
raise ValueError(f"Existing skill YAML at {file_path} is not a mapping")
# Shallow merge: top-level keys in config_patch overwrite.
# The 'name' field is preserved (cannot rename via patch).
patched = {**existing, **config_patch, "name": existing.get("name", normalized)}
# Validate the patched config by round-tripping through YAML.
try:
patched_yaml = yaml.safe_dump(patched, default_flow_style=False, allow_unicode=True)
except yaml.YAMLError as exc:
raise ValueError(f"Failed to serialize patched config: {exc}") from exc
_validate_yaml_content(patched_yaml)
with open(file_path, "w", encoding="utf-8") as f:
f.write(patched_yaml)
# Reload the skill into the registry.
try:
skill_registry.unregister(normalized)
except Exception: # noqa: BLE001 — unregister is best-effort
logger.debug("Skill %r was not registered before update", normalized)
from agentkit.skills.loader import SkillLoader
loader = SkillLoader(
skill_registry=skill_registry,
tool_registry=tool_registry,
)
loader.load_from_file(file_path)
return {
"name": normalized,
"path": file_path,
"status": "updated",
}
# ---------------------------------------------------------------------------
# Module-level singleton (overridable in tests via set_skill_service)
# ---------------------------------------------------------------------------
_skill_service: SkillService | None = None
def get_skill_service() -> SkillService:
"""Return the process-wide :class:`SkillService` (lazy singleton)."""
global _skill_service
if _skill_service is None:
_skill_service = SkillService()
return _skill_service
def set_skill_service(service: SkillService | None) -> None:
"""Inject a custom :class:`SkillService` (used by tests)."""
global _skill_service
_skill_service = service

View File

@ -351,6 +351,24 @@ class DepartmentQuotaModel(Base):
updated_at: Mapped[str] = mapped_column(String(64), nullable=False, default=_now_iso) updated_at: Mapped[str] = mapped_column(String(64), nullable=False, default=_now_iso)
class SkillStateModel(Base):
"""Skill enable/disable state (V4 — Admin Console U6).
Each row records whether a named skill has been disabled by an
admin via the ``/admin/skills/{name}/disable`` endpoint. Skills
with no row here are considered enabled (the default). ``skill_name``
references the skill registry identifier (not a DB FK skills are
defined in YAML configs).
"""
__tablename__ = "skill_states"
skill_name: Mapped[str] = mapped_column(String(128), primary_key=True)
is_disabled: Mapped[bool] = mapped_column(default=True, nullable=False)
disabled_at: Mapped[str] = mapped_column(String(64), nullable=False, default=_now_iso)
disabled_by: Mapped[str | None] = mapped_column(String(36), nullable=True)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Schema DDL (kept in sync with the models above for aiosqlite bootstrap) # Schema DDL (kept in sync with the models above for aiosqlite bootstrap)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -565,6 +583,17 @@ CREATE TABLE IF NOT EXISTS department_quotas (
); );
CREATE INDEX IF NOT EXISTS idx_department_quotas_department_id CREATE INDEX IF NOT EXISTS idx_department_quotas_department_id
ON department_quotas(department_id); ON department_quotas(department_id);
-- V4: skill_states records admin-disabled skills (U6 Admin Console).
-- A skill with no row here is considered enabled (the default). Only
-- disabled skills have a row, with is_disabled=1. disabled_by records
-- the admin user id who disabled the skill (audit trail).
CREATE TABLE IF NOT EXISTS skill_states (
skill_name TEXT PRIMARY KEY,
is_disabled INTEGER NOT NULL DEFAULT 1,
disabled_at TEXT NOT NULL,
disabled_by TEXT
);
""" """
@ -580,7 +609,10 @@ CREATE INDEX IF NOT EXISTS idx_department_quotas_department_id
# V3 (2026-06-21, Admin Console): added departments, user_departments, # V3 (2026-06-21, Admin Console): added departments, user_departments,
# department_skill_bindings, department_kb_bindings, department_quotas. # department_skill_bindings, department_kb_bindings, department_quotas.
# No backfill needed — all new tables are additive. # No backfill needed — all new tables are additive.
_SCHEMA_VERSION = 3 #
# V4 (2026-06-21, Admin Console U6): added skill_states table for
# admin-driven skill enable/disable. No backfill needed — additive.
_SCHEMA_VERSION = 4
_META_SCHEMA_VERSION_KEY = "schema_version" _META_SCHEMA_VERSION_KEY = "schema_version"
@ -803,3 +835,18 @@ def user_department_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> di
"department_id": row["department_id"], "department_id": row["department_id"],
"created_at": row["created_at"], "created_at": row["created_at"],
} }
def skill_state_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> dict[str, Any]:
"""Convert a ``skill_states`` row into a JSON-safe dict.
The ``is_disabled`` field is normalized to a Python ``bool`` (the DB
stores 0/1). ``disabled_by`` is ``None`` when the disabling admin
is not recorded (e.g. legacy rows or system-initiated disables).
"""
return {
"skill_name": row["skill_name"],
"is_disabled": bool(row["is_disabled"]),
"disabled_at": row["disabled_at"],
"disabled_by": row["disabled_by"],
}

View File

@ -23,11 +23,13 @@ from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from agentkit.server.admin.department_service import get_department_service from agentkit.server.admin.department_service import get_department_service
from agentkit.server.admin.kb_service import get_kb_service
from agentkit.server.admin.llm_config_service import ( from agentkit.server.admin.llm_config_service import (
LlmConfigService, LlmConfigService,
get_llm_config_service, get_llm_config_service,
) )
from agentkit.server.admin.quota_service import get_quota_service from agentkit.server.admin.quota_service import get_quota_service
from agentkit.server.admin.skill_service import get_skill_service
from agentkit.server.admin.user_service import get_user_service from agentkit.server.admin.user_service import get_user_service
from agentkit.server.auth.dependencies import require_authenticated from agentkit.server.auth.dependencies import require_authenticated
from agentkit.server.auth.models import DEFAULT_AUTH_DB_PATH, init_auth_db from agentkit.server.auth.models import DEFAULT_AUTH_DB_PATH, init_auth_db
@ -854,3 +856,285 @@ async def delete_department_quota(
if not deleted: if not deleted:
raise HTTPException(status_code=404, detail="Quota not found") raise HTTPException(status_code=404, detail="Quota not found")
return {"deleted": True} return {"deleted": True}
# ---------------------------------------------------------------------------
# Skill management endpoints (U6) — enable/disable + import/reload
# ---------------------------------------------------------------------------
def _get_skills_dir(request: Request) -> str:
"""Resolve the skills directory from ``app.state.server_config``.
Mirrors the logic in :func:`agentkit.server.routes.skills._get_skills_dir`
so admin endpoints install/reload skills into the same directory as
the existing ``/skills/install`` route.
"""
server_config = getattr(request.app.state, "server_config", None)
if server_config and getattr(server_config, "skill_paths", None):
first_path = Path(server_config.skill_paths[0])
if first_path.is_dir():
return str(first_path)
# Fallback: configs/skills/ relative to cwd (matches routes.skills).
import os
return os.path.join(os.getcwd(), "configs", "skills")
def _get_skill_registry(request: Request) -> Any:
"""Return the live :class:`SkillRegistry` from ``app.state``.
Raises HTTPException(500) if the registry is missing admin skill
endpoints cannot function without it.
"""
registry = getattr(request.app.state, "skill_registry", None)
if registry is None:
raise HTTPException(
status_code=500,
detail="Skill registry not initialized on app.state",
)
return registry
class SkillImportRequest(BaseModel):
"""Body for ``POST /admin/skills/import``."""
model_config = ConfigDict(extra="forbid")
yaml_content: str
class SkillUpdateRequest(BaseModel):
"""Body for ``PATCH /admin/skills/{name}``."""
model_config = ConfigDict(extra="forbid")
config: dict[str, Any]
@admin_router.post("/skills/{name}/enable")
async def enable_skill(
name: str,
request: Request,
admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
"""Enable a previously-disabled skill.
Returns 200 ``{enabled: true}`` on success. Idempotent returns
200 even if the skill was not disabled.
"""
db_path = await _ensure_db(request)
svc = get_skill_service()
try:
enabled = await svc.enable_skill(db_path, name)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
return {"enabled": enabled, "skill_name": name}
@admin_router.post("/skills/{name}/disable")
async def disable_skill(
name: str,
request: Request,
admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
"""Disable a skill (hides it from ``GET /skills``).
Returns 200 with the disabled skill state dict.
"""
db_path = await _ensure_db(request)
svc = get_skill_service()
try:
return await svc.disable_skill(db_path, name, disabled_by=admin.get("user_id"))
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
@admin_router.patch("/skills/{name}")
async def update_skill(
name: str,
payload: SkillUpdateRequest,
request: Request,
admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
"""Update a skill's YAML config in-place and reload it.
Returns 200 with the updated skill info. Returns 404 if the skill
YAML file does not exist, 400 if the patched config is invalid.
"""
skills_dir = _get_skills_dir(request)
registry = _get_skill_registry(request)
svc = get_skill_service()
try:
return await svc.update_skill_config(
name,
payload.config,
skills_dir,
registry,
)
except ValueError as exc:
msg = str(exc)
if "not found" in msg:
raise HTTPException(status_code=404, detail=msg) from exc
raise HTTPException(status_code=400, detail=msg) from exc
@admin_router.post("/skills/import")
async def import_skill(
payload: SkillImportRequest,
request: Request,
admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
"""Import a skill from YAML content.
Writes the YAML to ``{skills_dir}/{name}.yaml`` and registers it
in the live :class:`SkillRegistry`. Returns 200 with the imported
skill info. Returns 400 if the YAML is invalid or the skill name
is invalid.
"""
skills_dir = _get_skills_dir(request)
registry = _get_skill_registry(request)
svc = get_skill_service()
try:
return await svc.import_skill(
payload.yaml_content,
skills_dir,
skill_registry=registry,
)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
@admin_router.post("/skills/{name}/reload")
async def reload_skill(
name: str,
request: Request,
admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
"""Reload a skill from its YAML file.
Returns 200 with the reloaded skill info. Returns 404 if the YAML
file does not exist, 400 if the skill name is invalid.
"""
skills_dir = _get_skills_dir(request)
registry = _get_skill_registry(request)
svc = get_skill_service()
try:
return await svc.reload_skill(name, registry, skills_dir)
except ValueError as exc:
msg = str(exc)
if "not found" in msg:
raise HTTPException(status_code=404, detail=msg) from exc
raise HTTPException(status_code=400, detail=msg) from exc
# ---------------------------------------------------------------------------
# KB management endpoints (U6) — documents + sources
# ---------------------------------------------------------------------------
class KbDocumentUploadRequest(BaseModel):
"""Body for ``POST /admin/kb/documents``."""
model_config = ConfigDict(extra="forbid")
filename: str
content: str
source_id: str = ""
department_id: str | None = None
@admin_router.get("/kb/documents")
async def list_kb_documents(
request: Request,
source_id: str | None = None,
department_id: str | None = None,
admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
"""List KB documents (admin sees all).
Query params: ``source_id`` (optional filter), ``department_id``
(optional filter when set, only documents bound to that
department or global documents are returned).
"""
svc = get_kb_service()
# Admin sees everything. If department_id is provided as a query
# filter, narrow to that department + global docs.
if department_id is not None:
dept_ids = [department_id]
else:
dept_ids = None # admin: no filtering
documents = svc.list_documents(department_ids=dept_ids, source_id=source_id)
return {"documents": documents}
@admin_router.post("/kb/documents", status_code=201)
async def upload_kb_document(
payload: KbDocumentUploadRequest,
request: Request,
admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
"""Upload a KB document with optional department binding.
Returns 201 with the uploaded document dict.
"""
svc = get_kb_service()
return svc.upload_document(
filename=payload.filename,
content=payload.content.encode("utf-8"),
source_id=payload.source_id,
department_id=payload.department_id,
)
@admin_router.delete("/kb/documents/{document_id}")
async def delete_kb_document(
document_id: str,
request: Request,
admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
"""Delete a KB document by id.
Returns 200 ``{deleted: true}`` on success, 404 if not found.
"""
svc = get_kb_service()
deleted = svc.delete_document(document_id)
if not deleted:
raise HTTPException(status_code=404, detail="Document not found")
return {"deleted": True}
@admin_router.post("/kb/sources/{source_id}/sync")
async def sync_kb_source(
source_id: str,
request: Request,
admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
"""Trigger a sync for a KB source.
Returns 200 with the sync status. Returns 404 if the source does
not exist.
"""
svc = get_kb_service()
try:
return svc.sync_source(source_id)
except ValueError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc
@admin_router.post("/kb/sources/{source_id}/rebuild")
async def rebuild_kb_source(
source_id: str,
request: Request,
admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
"""Rebuild the index for a KB source.
Returns 200 with the rebuild status. Returns 404 if the source
does not exist.
"""
svc = get_kb_service()
try:
return svc.rebuild_index(source_id)
except ValueError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc

View File

@ -73,6 +73,7 @@ class KnowledgeSource:
status: str = "active" status: str = "active"
document_count: int = 0 document_count: int = 0
last_synced: str | None = None last_synced: str | None = None
department_id: str | None = None
@dataclass @dataclass
@ -83,6 +84,7 @@ class UploadedDocument:
chunks: int chunks: int
status: str status: str
created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
department_id: str | None = None
class KnowledgeSourceStore: class KnowledgeSourceStore:
@ -155,6 +157,23 @@ class KnowledgeSourceStore:
_source_store = KnowledgeSourceStore() _source_store = KnowledgeSourceStore()
def get_source_store() -> KnowledgeSourceStore:
"""Return the process-wide :class:`KnowledgeSourceStore` singleton.
Exposed as a function (rather than importing ``_source_store``
directly) so tests can swap it out via monkeypatch, and so the
:class:`KbService` wrapper can access the store without reaching
into a private module attribute.
"""
return _source_store
def set_source_store(store: KnowledgeSourceStore) -> None:
"""Inject a custom :class:`KnowledgeSourceStore` (used by tests)."""
global _source_store
_source_store = store
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Request / Response models # Request / Response models
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View File

@ -54,9 +54,7 @@ def _skill_to_info(skill: Any) -> dict[str, Any]:
if hasattr(skill, "config") and hasattr(skill.config, "capabilities"): if hasattr(skill, "config") and hasattr(skill.config, "capabilities"):
caps = skill.config.capabilities caps = skill.config.capabilities
if isinstance(caps, list): if isinstance(caps, list):
capabilities = [ capabilities = [c.tag if hasattr(c, "tag") else str(c) for c in caps]
c.tag if hasattr(c, "tag") else str(c) for c in caps
]
elif isinstance(caps, dict): elif isinstance(caps, dict):
capabilities = list(caps.keys()) capabilities = list(caps.keys())
@ -160,7 +158,7 @@ async def check_skill_health(skill_name: str, req: Request):
"""Check the health of a specific skill.""" """Check the health of a specific skill."""
skill_registry = req.app.state.skill_registry skill_registry = req.app.state.skill_registry
try: try:
skill = skill_registry.get(skill_name) skill_registry.get(skill_name)
except Exception: except Exception:
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found") raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
@ -213,17 +211,43 @@ async def list_capabilities(req: Request):
@router.post("/skill-management/skills/{skill_name}/reload") @router.post("/skill-management/skills/{skill_name}/reload")
async def reload_skill(skill_name: str, req: Request): async def reload_skill(skill_name: str, req: Request):
"""Reload a skill configuration.""" """Reload a skill configuration from its YAML file."""
skill_registry = req.app.state.skill_registry skill_registry = req.app.state.skill_registry
# Verify the skill is currently registered (404 if not).
try: try:
skill = skill_registry.get(skill_name) skill_registry.get(skill_name)
except Exception: except Exception:
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found") raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
# In a full implementation, this would reload the skill from its config source # Resolve the skills directory (mirrors routes.skills._get_skills_dir).
# For now, just return success import os
skills_dir: str
server_config = getattr(req.app.state, "server_config", None)
if server_config and getattr(server_config, "skill_paths", None):
from pathlib import Path as _P
first_path = _P(server_config.skill_paths[0])
if first_path.is_dir():
skills_dir = str(first_path)
else:
skills_dir = os.path.join(os.getcwd(), "configs", "skills")
else:
skills_dir = os.path.join(os.getcwd(), "configs", "skills")
from agentkit.server.admin.skill_service import get_skill_service
svc = get_skill_service()
try:
result = await svc.reload_skill(skill_name, skill_registry, skills_dir)
except ValueError as exc:
msg = str(exc)
if "not found" in msg:
raise HTTPException(status_code=404, detail=msg) from exc
raise HTTPException(status_code=400, detail=msg) from exc
return { return {
"skill_name": skill_name, "skill_name": result["name"],
"status": "reloaded", "status": result["status"],
"message": f"技能 '{skill_name}' 已重新加载", "message": f"技能 '{skill_name}' 已重新加载",
} }

View File

@ -49,6 +49,7 @@ def _get_skills_dir(req: Request) -> str:
if server_config and server_config.skill_paths: if server_config and server_config.skill_paths:
# Use the first configured skill path as the install target # Use the first configured skill path as the install target
from pathlib import Path as _P from pathlib import Path as _P
first_path = _P(server_config.skill_paths[0]) first_path = _P(server_config.skill_paths[0])
if first_path.is_dir(): if first_path.is_dir():
return str(first_path) return str(first_path)
@ -59,12 +60,16 @@ def _get_skills_dir(req: Request) -> str:
def _validate_source_url(source: str) -> None: def _validate_source_url(source: str) -> None:
"""Validate that a source URL points to an allowed domain (SSRF mitigation).""" """Validate that a source URL points to an allowed domain (SSRF mitigation)."""
from urllib.parse import urlparse from urllib.parse import urlparse
parsed = urlparse(source) parsed = urlparse(source)
if parsed.scheme not in ("https", "http"): if parsed.scheme not in ("https", "http"):
raise HTTPException(status_code=400, detail=f"Invalid source URL scheme: only http/https allowed") raise HTTPException(
status_code=400, detail="Invalid source URL scheme: only http/https allowed"
)
# Block private/internal IPs by checking hostname # Block private/internal IPs by checking hostname
import ipaddress import ipaddress
import socket import socket
hostname = parsed.hostname hostname = parsed.hostname
if hostname: if hostname:
try: try:
@ -82,12 +87,15 @@ def _validate_source_url(source: str) -> None:
# Check domain allowlist for source URLs # Check domain allowlist for source URLs
if hostname and hostname not in _ALLOWED_DOWNLOAD_DOMAINS: if hostname and hostname not in _ALLOWED_DOWNLOAD_DOMAINS:
# Allow but log a warning for non-allowlisted domains # Allow but log a warning for non-allowlisted domains
logger.warning(f"Source URL domain '{hostname}' is not in the allowlist: {_ALLOWED_DOWNLOAD_DOMAINS}") logger.warning(
f"Source URL domain '{hostname}' is not in the allowlist: {_ALLOWED_DOWNLOAD_DOMAINS}"
)
def _validate_yaml_content(content: str) -> dict: def _validate_yaml_content(content: str) -> dict:
"""Validate YAML content before writing to disk. Returns parsed dict.""" """Validate YAML content before writing to disk. Returns parsed dict."""
import yaml import yaml
try: try:
data = yaml.safe_load(content) data = yaml.safe_load(content)
except yaml.YAMLError as e: except yaml.YAMLError as e:
@ -156,16 +164,37 @@ async def list_skills(
Admin users (``role == "admin"``) bypass filtering and see all Admin users (``role == "admin"``) bypass filtering and see all
registered skills. Unauthenticated callers (API-key clients) see registered skills. Unauthenticated callers (API-key clients) see
only global skills. only global skills.
Disabled-skill filtering (U6): skills marked as disabled by an
admin via ``POST /admin/skills/{name}/disable`` are excluded from
the response for ALL callers (admins included). This keeps the
public skill list in sync with the admin enable/disable state.
""" """
skill_registry = req.app.state.skill_registry skill_registry = req.app.state.skill_registry
skills = skill_registry.list_skills() skills = skill_registry.list_skills()
# U6: filter out admin-disabled skills (applies to all callers).
db_path = _resolve_db_path(req)
if db_path.exists():
try:
from agentkit.server.admin.skill_service import get_skill_service
svc = get_skill_service()
disabled_names = set(await svc.list_disabled_skills(db_path))
except Exception: # noqa: BLE001 — never block listing on DB errors
logger.exception("Failed to load disabled skills list — skipping filter")
disabled_names = set()
else:
disabled_names = set()
if disabled_names:
skills = [s for s in skills if s.name not in disabled_names]
# Admins bypass department filtering. # Admins bypass department filtering.
if not dept_ctx.should_filter: if not dept_ctx.should_filter:
return _serialize_skills(skills) return _serialize_skills(skills)
# Non-admin: filter by department bindings. # Non-admin: filter by department bindings.
db_path = _resolve_db_path(req)
all_names = [s.name for s in skills] all_names = [s.name for s in skills]
try: try:
visible_names = await filter_skills_by_department( visible_names = await filter_skills_by_department(
@ -218,15 +247,13 @@ async def mention_suggest(q: str = "", req: Request = None):
if query: if query:
skills = [ skills = [
s for s in skills s
for s in skills
if query in s.name.lower() if query in s.name.lower()
or (s.config.description and query in s.config.description.lower()) or (s.config.description and query in s.config.description.lower())
] ]
return [ return [{"name": s.name, "description": s.config.description or ""} for s in skills[:8]]
{"name": s.name, "description": s.config.description or ""}
for s in skills[:8]
]
@router.post("/skills/install") @router.post("/skills/install")
@ -246,7 +273,9 @@ async def install_skill(request: InstallSkillRequest, req: Request):
if source and source.startswith("http"): if source and source.startswith("http"):
_validate_source_url(source) _validate_source_url(source)
try: try:
async with httpx.AsyncClient(timeout=30, follow_redirects=True, max_redirects=3) as client: async with httpx.AsyncClient(
timeout=30, follow_redirects=True, max_redirects=3
) as client:
resp = await client.get(source) resp = await client.get(source)
resp.raise_for_status() resp.raise_for_status()
yaml_content = resp.text yaml_content = resp.text
@ -260,7 +289,9 @@ async def install_skill(request: InstallSkillRequest, req: Request):
# Verify the path is within the skills directory # Verify the path is within the skills directory
skills_dir_base = _get_skills_dir(req) skills_dir_base = _get_skills_dir(req)
if not os.path.realpath(local_path).startswith(os.path.realpath(skills_dir_base)): if not os.path.realpath(local_path).startswith(os.path.realpath(skills_dir_base)):
raise HTTPException(status_code=400, detail="Local file path must be within the skills directory") raise HTTPException(
status_code=400, detail="Local file path must be within the skills directory"
)
try: try:
with open(local_path, encoding="utf-8") as f: with open(local_path, encoding="utf-8") as f:
yaml_content = f.read() yaml_content = f.read()
@ -290,12 +321,17 @@ async def install_skill(request: InstallSkillRequest, req: Request):
# Fallback: try a simpler search # Fallback: try a simpler search
search_query2 = f"{skill_name} skill" search_query2 = f"{skill_name} skill"
encoded_query2 = urllib.parse.quote(search_query2) encoded_query2 = urllib.parse.quote(search_query2)
github_api2 = f"https://api.github.com/search/code?q={encoded_query2}+extension:yaml&per_page=5" github_api2 = (
f"https://api.github.com/search/code?q={encoded_query2}+extension:yaml&per_page=5"
)
try: try:
async with httpx.AsyncClient(timeout=15) as client: async with httpx.AsyncClient(timeout=15) as client:
gh_resp2 = await client.get( gh_resp2 = await client.get(
github_api2, github_api2,
headers={"Accept": "application/vnd.github.v3+json", "User-Agent": "agentkit"}, headers={
"Accept": "application/vnd.github.v3+json",
"User-Agent": "agentkit",
},
) )
items = gh_resp2.json().get("items", []) items = gh_resp2.json().get("items", [])
except Exception: except Exception:
@ -310,13 +346,19 @@ async def install_skill(request: InstallSkillRequest, req: Request):
if raw_url: if raw_url:
# Validate the URL is from github.com before transforming # Validate the URL is from github.com before transforming
if not raw_url.startswith("https://github.com/"): if not raw_url.startswith("https://github.com/"):
raise HTTPException(status_code=400, detail="Search result URL is not from github.com") raise HTTPException(
raw_url = raw_url.replace("github.com", "raw.githubusercontent.com").replace("/blob/", "/") status_code=400, detail="Search result URL is not from github.com"
)
raw_url = raw_url.replace("github.com", "raw.githubusercontent.com").replace(
"/blob/", "/"
)
else: else:
raise HTTPException(status_code=404, detail="Could not construct download URL") raise HTTPException(status_code=404, detail="Could not construct download URL")
try: try:
async with httpx.AsyncClient(timeout=30, follow_redirects=True, max_redirects=3) as client: async with httpx.AsyncClient(
timeout=30, follow_redirects=True, max_redirects=3
) as client:
resp = await client.get(raw_url) resp = await client.get(raw_url)
resp.raise_for_status() resp.raise_for_status()
yaml_content = resp.text yaml_content = resp.text
@ -342,6 +384,7 @@ async def install_skill(request: InstallSkillRequest, req: Request):
registration_ok = False registration_ok = False
try: try:
from agentkit.skills.loader import SkillLoader from agentkit.skills.loader import SkillLoader
loader = SkillLoader( loader = SkillLoader(
skill_registry=skill_registry, skill_registry=skill_registry,
tool_registry=tool_registry, tool_registry=tool_registry,
@ -357,7 +400,7 @@ async def install_skill(request: InstallSkillRequest, req: Request):
os.remove(file_path) os.remove(file_path)
except Exception: except Exception:
pass pass
raise HTTPException(status_code=500, detail=f"Skill downloaded but registration failed") raise HTTPException(status_code=500, detail="Skill downloaded but registration failed")
return { return {
"status": "installed", "status": "installed",
@ -387,7 +430,9 @@ async def uninstall_skill(name: str, req: Request):
yaml_path = os.path.join(skills_dir, f"{validated_name}.yaml") yaml_path = os.path.join(skills_dir, f"{validated_name}.yaml")
# Verify resolved path stays within skills_dir # Verify resolved path stays within skills_dir
if os.path.exists(yaml_path) and os.path.realpath(yaml_path).startswith(os.path.realpath(skills_dir)): if os.path.exists(yaml_path) and os.path.realpath(yaml_path).startswith(
os.path.realpath(skills_dir)
):
os.remove(yaml_path) os.remove(yaml_path)
return {"status": "uninstalled", "name": validated_name} return {"status": "uninstalled", "name": validated_name}
@ -419,8 +464,7 @@ async def create_pipeline(request: CreatePipelineRequest, req: Request):
return { return {
"name": pipeline.name, "name": pipeline.name,
"steps": [ "steps": [
{"skill_name": s["skill_name"], "step_index": i} {"skill_name": s["skill_name"], "step_index": i} for i, s in enumerate(request.steps)
for i, s in enumerate(request.steps)
], ],
} }

View File

@ -0,0 +1,522 @@
"""Integration tests for the skill/KB admin routes (U6).
Uses FastAPI TestClient with a test app that mounts the
``admin_router`` from ``routes.admin`` plus the public ``skills``
router from ``routes.skills`` (to verify disabled-skill filtering).
The ``_require_admin`` dependency is overridden via
``app.dependency_overrides`` so the tests don't need real JWTs.
The :class:`SkillRegistry` is a real instance with a test skill
loaded from a temp YAML file, so import/reload/update endpoints
exercise the real SkillLoader pipeline.
"""
from __future__ import annotations
import os
import uuid
from pathlib import Path
from typing import Any
import pytest
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from agentkit.server.admin.kb_service import set_kb_service
from agentkit.server.admin.skill_service import set_skill_service
from agentkit.server.auth.models import init_auth_db
from agentkit.server.routes import admin as admin_routes_module
from agentkit.server.routes import skills as skills_routes_module
from agentkit.skills.registry import SkillRegistry
# ---------------------------------------------------------------------------
# Test data
# ---------------------------------------------------------------------------
_VALID_SKILL_YAML = """\
name: admin_test_skill
agent_type: simple_generation
version: "1.0.0"
description: "A test skill for admin route testing"
task_mode: llm_generate
execution_mode: direct
max_steps: 1
prompt:
identity: "Test"
instructions: "Handle test"
tools: []
"""
_VALID_SKILL_YAML_2 = """\
name: another_test_skill
agent_type: simple_generation
version: "1.0.0"
description: "Another test skill"
task_mode: llm_generate
execution_mode: direct
max_steps: 1
prompt:
identity: "Test2"
instructions: "Handle test 2"
tools: []
"""
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
async def tmp_auth_db(tmp_path: Path) -> Path:
db_path = tmp_path / "admin_skill_kb.db"
await init_auth_db(db_path)
return db_path
@pytest.fixture
def skills_dir(tmp_path: Path) -> str:
"""A temp skills directory for YAML files."""
d = tmp_path / "skills"
d.mkdir()
return str(d)
@pytest.fixture
def skill_registry() -> SkillRegistry:
return SkillRegistry()
@pytest.fixture(autouse=True)
def _reset_singletons():
"""Reset SkillService and KbService singletons before/after each test."""
set_skill_service(None)
set_kb_service(None)
yield
set_skill_service(None)
set_kb_service(None)
def _make_admin_user() -> dict[str, Any]:
return {"user_id": "admin-1", "username": "admin", "role": "admin"}
def _raise_forbidden() -> dict[str, Any]:
"""Dependency override that simulates a non-admin (403) response."""
raise HTTPException(status_code=403, detail="Admin permission required")
@pytest.fixture
def admin_app(
tmp_auth_db: Path,
skills_dir: str,
skill_registry: SkillRegistry,
) -> FastAPI:
"""A minimal FastAPI app with admin + skills routers mounted.
The ``_require_admin`` dependency is overridden to return a fake
admin user. The :class:`SkillRegistry` is set on ``app.state`` so
the admin skill endpoints can access it.
"""
app = FastAPI()
app.state.auth_db_path = str(tmp_auth_db)
app.state.skill_registry = skill_registry
# Build a minimal server_config with skill_paths pointing at the
# temp skills_dir, so the admin endpoints write YAML there.
class _FakeServerConfig:
skill_paths = [skills_dir]
app.state.server_config = _FakeServerConfig()
app.include_router(admin_routes_module.admin_router, prefix="/api/v1")
app.include_router(skills_routes_module.router, prefix="/api/v1")
# Default: allow admin access.
app.dependency_overrides[admin_routes_module._require_admin] = lambda: _make_admin_user()
# Override the department context to admin (bypass filtering) so
# GET /skills returns all skills (the disabled filter still applies).
from agentkit.server.admin.context import DepartmentContext
app.dependency_overrides[skills_routes_module.get_department_context] = lambda: (
DepartmentContext(user_id="admin-1", department_ids=[], is_admin=True)
)
return app
@pytest.fixture
def admin_client(admin_app: FastAPI) -> TestClient:
return TestClient(admin_app)
def _write_skill_yaml(skills_dir: str, name: str, content: str) -> str:
"""Write a skill YAML file to the skills dir and return the path."""
path = os.path.join(skills_dir, f"{name}.yaml")
with open(path, "w", encoding="utf-8") as f:
f.write(content)
return path
# ---------------------------------------------------------------------------
# Skill enable/disable
# ---------------------------------------------------------------------------
class TestSkillEnableDisable:
def test_disable_skill_returns_200(self, admin_client: TestClient):
resp = admin_client.post("/api/v1/admin/skills/admin_test_skill/disable")
assert resp.status_code == 200, resp.text
body = resp.json()
assert body["skill_name"] == "admin_test_skill"
assert body["is_disabled"] is True
def test_disable_then_get_skills_excludes_it(
self,
admin_client: TestClient,
skills_dir: str,
skill_registry: SkillRegistry,
):
# Write a skill YAML and register it.
_write_skill_yaml(skills_dir, "admin_test_skill", _VALID_SKILL_YAML)
from agentkit.skills.loader import SkillLoader
SkillLoader(skill_registry=skill_registry).load_from_file(
os.path.join(skills_dir, "admin_test_skill.yaml")
)
# Verify the skill is initially listed.
resp = admin_client.get("/api/v1/skills")
assert resp.status_code == 200
names = [s["name"] for s in resp.json()]
assert "admin_test_skill" in names
# Disable the skill.
resp = admin_client.post("/api/v1/admin/skills/admin_test_skill/disable")
assert resp.status_code == 200
# GET /skills should now exclude it.
resp = admin_client.get("/api/v1/skills")
assert resp.status_code == 200
names = [s["name"] for s in resp.json()]
assert "admin_test_skill" not in names
def test_enable_skill_returns_200(
self,
admin_client: TestClient,
skills_dir: str,
skill_registry: SkillRegistry,
):
# Write and register the skill, then disable it.
_write_skill_yaml(skills_dir, "admin_test_skill", _VALID_SKILL_YAML)
from agentkit.skills.loader import SkillLoader
SkillLoader(skill_registry=skill_registry).load_from_file(
os.path.join(skills_dir, "admin_test_skill.yaml")
)
admin_client.post("/api/v1/admin/skills/admin_test_skill/disable")
# Enable it.
resp = admin_client.post("/api/v1/admin/skills/admin_test_skill/enable")
assert resp.status_code == 200
assert resp.json()["enabled"] is True
# GET /skills should now include it.
resp = admin_client.get("/api/v1/skills")
names = [s["name"] for s in resp.json()]
assert "admin_test_skill" in names
def test_enable_skill_not_disabled_returns_200(self, admin_client: TestClient):
"""Enabling a skill that wasn't disabled returns enabled=False."""
resp = admin_client.post("/api/v1/admin/skills/never_disabled/enable")
assert resp.status_code == 200
assert resp.json()["enabled"] is False
def test_disable_skill_invalid_name_returns_400(self, admin_client: TestClient):
resp = admin_client.post("/api/v1/admin/skills/Has Spaces/disable")
assert resp.status_code == 400
def test_non_admin_cannot_disable_skill(self, admin_app: FastAPI):
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
client = TestClient(admin_app)
resp = client.post("/api/v1/admin/skills/some_skill/disable")
assert resp.status_code == 403
# ---------------------------------------------------------------------------
# Skill import
# ---------------------------------------------------------------------------
class TestSkillImport:
def test_import_valid_yaml_returns_200(
self,
admin_client: TestClient,
skills_dir: str,
skill_registry: SkillRegistry,
):
resp = admin_client.post(
"/api/v1/admin/skills/import",
json={"yaml_content": _VALID_SKILL_YAML},
)
assert resp.status_code == 200, resp.text
body = resp.json()
assert body["name"] == "admin_test_skill"
assert os.path.isfile(body["path"])
# Skill should be registered.
assert skill_registry.has_skill("admin_test_skill")
def test_import_invalid_yaml_returns_400(self, admin_client: TestClient):
resp = admin_client.post(
"/api/v1/admin/skills/import",
json={"yaml_content": "not: valid: yaml: ["},
)
assert resp.status_code == 400
def test_import_missing_name_returns_400(self, admin_client: TestClient):
resp = admin_client.post(
"/api/v1/admin/skills/import",
json={"yaml_content": "agent_type: test\n"},
)
assert resp.status_code == 400
def test_import_invalid_skill_name_returns_400(self, admin_client: TestClient):
bad_yaml = 'name: "Has Spaces"\nagent_type: test\n'
resp = admin_client.post(
"/api/v1/admin/skills/import",
json={"yaml_content": bad_yaml},
)
assert resp.status_code == 400
def test_non_admin_cannot_import(self, admin_app: FastAPI):
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
client = TestClient(admin_app)
resp = client.post(
"/api/v1/admin/skills/import",
json={"yaml_content": _VALID_SKILL_YAML},
)
assert resp.status_code == 403
# ---------------------------------------------------------------------------
# Skill reload
# ---------------------------------------------------------------------------
class TestSkillReload:
def test_reload_existing_skill_returns_200(
self,
admin_client: TestClient,
skills_dir: str,
skill_registry: SkillRegistry,
):
# Write and register the skill first.
_write_skill_yaml(skills_dir, "admin_test_skill", _VALID_SKILL_YAML)
from agentkit.skills.loader import SkillLoader
SkillLoader(skill_registry=skill_registry).load_from_file(
os.path.join(skills_dir, "admin_test_skill.yaml")
)
resp = admin_client.post("/api/v1/admin/skills/admin_test_skill/reload")
assert resp.status_code == 200, resp.text
body = resp.json()
assert body["name"] == "admin_test_skill"
assert body["status"] == "reloaded"
def test_reload_nonexistent_skill_returns_404(self, admin_client: TestClient):
resp = admin_client.post("/api/v1/admin/skills/nonexistent/reload")
assert resp.status_code == 404
# ---------------------------------------------------------------------------
# Skill update (PATCH)
# ---------------------------------------------------------------------------
class TestSkillUpdate:
def test_update_skill_returns_200(
self,
admin_client: TestClient,
skills_dir: str,
skill_registry: SkillRegistry,
):
# Write and register the skill first.
_write_skill_yaml(skills_dir, "admin_test_skill", _VALID_SKILL_YAML)
from agentkit.skills.loader import SkillLoader
SkillLoader(skill_registry=skill_registry).load_from_file(
os.path.join(skills_dir, "admin_test_skill.yaml")
)
resp = admin_client.patch(
"/api/v1/admin/skills/admin_test_skill",
json={"config": {"description": "Updated via admin API"}},
)
assert resp.status_code == 200, resp.text
body = resp.json()
assert body["status"] == "updated"
# Verify the YAML file was updated.
import yaml
with open(body["path"], encoding="utf-8") as f:
data = yaml.safe_load(f)
assert data["description"] == "Updated via admin API"
def test_update_nonexistent_skill_returns_404(self, admin_client: TestClient):
resp = admin_client.patch(
"/api/v1/admin/skills/nonexistent",
json={"config": {"description": "x"}},
)
assert resp.status_code == 404
# ---------------------------------------------------------------------------
# KB document endpoints
# ---------------------------------------------------------------------------
class TestKbDocumentRoutes:
def test_list_documents_returns_200(self, admin_client: TestClient):
resp = admin_client.get("/api/v1/admin/kb/documents")
assert resp.status_code == 200, resp.text
body = resp.json()
assert "documents" in body
assert isinstance(body["documents"], list)
def test_upload_document_returns_201_with_department_id(self, admin_client: TestClient):
dept_id = str(uuid.uuid4())
resp = admin_client.post(
"/api/v1/admin/kb/documents",
json={
"filename": "test.txt",
"content": "hello world",
"department_id": dept_id,
},
)
assert resp.status_code == 201, resp.text
body = resp.json()
assert body["filename"] == "test.txt"
assert body["department_id"] == dept_id
assert "document_id" in body
def test_upload_document_without_department_id(self, admin_client: TestClient):
resp = admin_client.post(
"/api/v1/admin/kb/documents",
json={"filename": "global.txt", "content": "global content"},
)
assert resp.status_code == 201
body = resp.json()
assert body["department_id"] is None
def test_upload_then_delete_document(self, admin_client: TestClient):
# Upload
resp = admin_client.post(
"/api/v1/admin/kb/documents",
json={"filename": "delete_me.txt", "content": "x"},
)
assert resp.status_code == 201
doc_id = resp.json()["document_id"]
# Delete
resp = admin_client.delete(f"/api/v1/admin/kb/documents/{doc_id}")
assert resp.status_code == 200
assert resp.json() == {"deleted": True}
# Second delete should 404.
resp = admin_client.delete(f"/api/v1/admin/kb/documents/{doc_id}")
assert resp.status_code == 404
def test_delete_nonexistent_document_returns_404(self, admin_client: TestClient):
resp = admin_client.delete("/api/v1/admin/kb/documents/nonexistent-id")
assert resp.status_code == 404
def test_list_documents_filters_by_department_id(self, admin_client: TestClient):
dept_a = str(uuid.uuid4())
dept_b = str(uuid.uuid4())
# Upload 3 docs: one for dept_a, one for dept_b, one global.
admin_client.post(
"/api/v1/admin/kb/documents",
json={"filename": "a.txt", "content": "a", "department_id": dept_a},
)
admin_client.post(
"/api/v1/admin/kb/documents",
json={"filename": "b.txt", "content": "b", "department_id": dept_b},
)
admin_client.post(
"/api/v1/admin/kb/documents",
json={"filename": "global.txt", "content": "g"},
)
# Filter by dept_a: should see dept_a's doc + global doc.
resp = admin_client.get("/api/v1/admin/kb/documents", params={"department_id": dept_a})
assert resp.status_code == 200
docs = resp.json()["documents"]
dept_ids = {d["department_id"] for d in docs}
assert dept_a in dept_ids
assert None in dept_ids # global doc
assert dept_b not in dept_ids
def test_non_admin_cannot_list_documents(self, admin_app: FastAPI):
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
client = TestClient(admin_app)
resp = client.get("/api/v1/admin/kb/documents")
assert resp.status_code == 403
def test_non_admin_cannot_upload_document(self, admin_app: FastAPI):
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
client = TestClient(admin_app)
resp = client.post(
"/api/v1/admin/kb/documents",
json={"filename": "x.txt", "content": "x"},
)
assert resp.status_code == 403
# ---------------------------------------------------------------------------
# KB source sync/rebuild
# ---------------------------------------------------------------------------
class TestKbSourceRoutes:
def test_sync_nonexistent_source_returns_404(self, admin_client: TestClient):
resp = admin_client.post("/api/v1/admin/kb/sources/nonexistent/sync")
assert resp.status_code == 404
def test_rebuild_nonexistent_source_returns_404(self, admin_client: TestClient):
resp = admin_client.post("/api/v1/admin/kb/sources/nonexistent/rebuild")
assert resp.status_code == 404
def test_sync_existing_source_returns_200(self, admin_client: TestClient):
# Add a source via the KbService (bypassing the route layer).
from agentkit.server.admin.kb_service import get_kb_service
svc = get_kb_service()
store = svc._resolve_store()
source = store.add_source("Test Source", "local", {})
resp = admin_client.post(f"/api/v1/admin/kb/sources/{source.id}/sync")
assert resp.status_code == 200, resp.text
body = resp.json()
assert body["status"] == "syncing"
def test_rebuild_existing_source_returns_200(self, admin_client: TestClient):
from agentkit.server.admin.kb_service import get_kb_service
svc = get_kb_service()
store = svc._resolve_store()
source = store.add_source("Test Source", "local", {})
resp = admin_client.post(f"/api/v1/admin/kb/sources/{source.id}/rebuild")
assert resp.status_code == 200, resp.text
body = resp.json()
assert body["status"] == "rebuilding"
def test_non_admin_cannot_sync_source(self, admin_app: FastAPI):
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
client = TestClient(admin_app)
resp = client.post("/api/v1/admin/kb/sources/any/sync")
assert resp.status_code == 403

View File

@ -4,7 +4,7 @@ Covers:
- ``init_auth_db`` creates the new V3 tables (departments, user_departments, - ``init_auth_db`` creates the new V3 tables (departments, user_departments,
department_skill_bindings, department_kb_bindings, department_quotas) department_skill_bindings, department_kb_bindings, department_quotas)
- ``init_auth_db`` is idempotent (calling twice does not error) - ``init_auth_db`` is idempotent (calling twice does not error)
- ``_SCHEMA_VERSION`` is recorded as 3 in ``auth_meta`` - ``_SCHEMA_VERSION`` is recorded as 4 in ``auth_meta`` (V4 adds skill_states)
- ``departments`` insert + query round-trip - ``departments`` insert + query round-trip
- ``user_departments`` many-to-many relationship (one user many departments, - ``user_departments`` many-to-many relationship (one user many departments,
one department many users) one department many users)
@ -109,9 +109,7 @@ async def _list_index_names(db: aiosqlite.Connection, table: str) -> set[str]:
async def _list_table_names(db: aiosqlite.Connection) -> set[str]: async def _list_table_names(db: aiosqlite.Connection) -> set[str]:
"""Return the set of table names in the SQLite file.""" """Return the set of table names in the SQLite file."""
cursor = await db.execute( cursor = await db.execute("SELECT name FROM sqlite_master WHERE type='table'")
"SELECT name FROM sqlite_master WHERE type='table'"
)
rows = await cursor.fetchall() rows = await cursor.fetchall()
return {row[0] for row in rows} return {row[0] for row in rows}
@ -122,9 +120,9 @@ async def _list_table_names(db: aiosqlite.Connection) -> set[str]:
class TestSchemaVersion: class TestSchemaVersion:
def test_schema_version_is_v3(self): def test_schema_version_is_v4(self):
"""V3 adds the department-scoped admin tables.""" """V4 adds the skill_states table (U6 — Admin Console)."""
assert _SCHEMA_VERSION == 3 assert _SCHEMA_VERSION == 4
def test_sqlalchemy_model_table_names(self): def test_sqlalchemy_model_table_names(self):
assert DepartmentModel.__tablename__ == "departments" assert DepartmentModel.__tablename__ == "departments"
@ -165,13 +163,13 @@ class TestInitAuthDbTables:
tables = await _list_table_names(db) tables = await _list_table_names(db)
assert "department_quotas" in tables assert "department_quotas" in tables
async def test_records_schema_version_3_in_auth_meta(self, fresh_db: Path): async def test_records_schema_version_4_in_auth_meta(self, fresh_db: Path):
async with aiosqlite.connect(str(fresh_db)) as db: async with aiosqlite.connect(str(fresh_db)) as db:
db.row_factory = aiosqlite.Row db.row_factory = aiosqlite.Row
cursor = await db.execute("SELECT value FROM auth_meta WHERE key='schema_version'") cursor = await db.execute("SELECT value FROM auth_meta WHERE key='schema_version'")
row = await cursor.fetchone() row = await cursor.fetchone()
assert row is not None assert row is not None
assert row["value"] == "3" assert row["value"] == "4"
assert row["value"] == str(_SCHEMA_VERSION) assert row["value"] == str(_SCHEMA_VERSION)
async def test_init_auth_db_is_idempotent(self, tmp_path: Path): async def test_init_auth_db_is_idempotent(self, tmp_path: Path):
@ -279,9 +277,7 @@ class TestDepartmentsCrud:
) )
await db.commit() await db.commit()
db.row_factory = aiosqlite.Row db.row_factory = aiosqlite.Row
cursor = await db.execute( cursor = await db.execute("SELECT description FROM departments WHERE id=?", (dept_id,))
"SELECT description FROM departments WHERE id=?", (dept_id,)
)
row = await cursor.fetchone() row = await cursor.fetchone()
assert row is not None assert row is not None
assert row["description"] is None assert row["description"] is None
@ -310,8 +306,7 @@ class TestUserDepartmentsManyToMany:
await db.commit() await db.commit()
db.row_factory = aiosqlite.Row db.row_factory = aiosqlite.Row
cursor = await db.execute( cursor = await db.execute(
"SELECT department_id FROM user_departments WHERE user_id=? " "SELECT department_id FROM user_departments WHERE user_id=? ORDER BY department_id",
"ORDER BY department_id",
(user_id,), (user_id,),
) )
rows = await cursor.fetchall() rows = await cursor.fetchall()
@ -336,8 +331,7 @@ class TestUserDepartmentsManyToMany:
await db.commit() await db.commit()
db.row_factory = aiosqlite.Row db.row_factory = aiosqlite.Row
cursor = await db.execute( cursor = await db.execute(
"SELECT user_id FROM user_departments WHERE department_id=? " "SELECT user_id FROM user_departments WHERE department_id=? ORDER BY user_id",
"ORDER BY user_id",
(dept_id,), (dept_id,),
) )
rows = await cursor.fetchall() rows = await cursor.fetchall()
@ -394,9 +388,7 @@ class TestDepartmentSkillBindingsUnique:
) )
await db.commit() await db.commit()
async def test_same_skill_name_in_different_departments_is_allowed( async def test_same_skill_name_in_different_departments_is_allowed(self, fresh_db: Path):
self, fresh_db: Path
):
dept_a = str(uuid.uuid4()) dept_a = str(uuid.uuid4())
dept_b = str(uuid.uuid4()) dept_b = str(uuid.uuid4())
async with aiosqlite.connect(str(fresh_db)) as db: async with aiosqlite.connect(str(fresh_db)) as db:

View File

@ -0,0 +1,386 @@
"""Unit tests for SkillService (U6 — skill enable/disable + import/reload).
Covers:
- disable_skill is_skill_disabled returns True
- enable_skill is_skill_disabled returns False
- list_disabled_skills returns correct list
- import_skill with valid YAML file written, skill loaded
- import_skill with invalid YAML ValueError
- import_skill with path traversal attempt ValueError
- reload_skill unregister + reload
- update_skill_config YAML updated + reloaded
"""
from __future__ import annotations
from pathlib import Path
import pytest
from agentkit.server.admin.skill_service import SkillService, _validate_skill_name
from agentkit.server.auth.models import init_auth_db
from agentkit.skills.registry import SkillRegistry
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
async def fresh_db(tmp_path: Path) -> Path:
"""A brand-new auth DB on a fresh path (no data)."""
db_path = tmp_path / "auth.db"
await init_auth_db(db_path)
return db_path
@pytest.fixture
def service() -> SkillService:
return SkillService()
@pytest.fixture
def skills_dir(tmp_path: Path) -> str:
"""A temp skills directory for YAML files."""
d = tmp_path / "skills"
d.mkdir()
return str(d)
@pytest.fixture
def skill_registry() -> SkillRegistry:
return SkillRegistry()
_VALID_SKILL_YAML = """\
name: test_skill
agent_type: simple_generation
version: "1.0.0"
description: "A test skill for unit testing"
task_mode: llm_generate
execution_mode: direct
max_steps: 1
prompt:
identity: "Test"
instructions: "Handle test"
tools: []
"""
# ---------------------------------------------------------------------------
# Enable / disable
# ---------------------------------------------------------------------------
class TestDisableEnable:
async def test_disable_skill_marks_disabled(self, service: SkillService, fresh_db: Path):
result = await service.disable_skill(fresh_db, "test_skill", disabled_by="admin-1")
assert result["skill_name"] == "test_skill"
assert result["is_disabled"] is True
assert result["disabled_by"] == "admin-1"
assert "disabled_at" in result
assert await service.is_skill_disabled(fresh_db, "test_skill") is True
async def test_disable_skill_normalizes_name(self, service: SkillService, fresh_db: Path):
"""Skill names are normalized to lowercase before storage."""
await service.disable_skill(fresh_db, "TestSkill")
assert await service.is_skill_disabled(fresh_db, "testskill") is True
assert await service.is_skill_disabled(fresh_db, "TestSkill") is True
async def test_disable_skill_is_idempotent(self, service: SkillService, fresh_db: Path):
"""Disabling an already-disabled skill updates the row, not duplicates."""
await service.disable_skill(fresh_db, "skill_a", disabled_by="admin-1")
await service.disable_skill(fresh_db, "skill_a", disabled_by="admin-2")
disabled = await service.list_disabled_skills(fresh_db)
assert disabled == ["skill_a"]
async def test_enable_skill_removes_disabled_mark(self, service: SkillService, fresh_db: Path):
await service.disable_skill(fresh_db, "skill_b")
assert await service.is_skill_disabled(fresh_db, "skill_b") is True
enabled = await service.enable_skill(fresh_db, "skill_b")
assert enabled is True
assert await service.is_skill_disabled(fresh_db, "skill_b") is False
async def test_enable_skill_returns_false_if_not_disabled(
self, service: SkillService, fresh_db: Path
):
"""Enabling a skill that wasn't disabled returns False (no-op)."""
enabled = await service.enable_skill(fresh_db, "never_disabled")
assert enabled is False
async def test_is_skill_disabled_returns_false_for_unknown(
self, service: SkillService, fresh_db: Path
):
assert await service.is_skill_disabled(fresh_db, "unknown_skill") is False
async def test_list_disabled_skills_returns_sorted_list(
self, service: SkillService, fresh_db: Path
):
await service.disable_skill(fresh_db, "charlie")
await service.disable_skill(fresh_db, "alpha")
await service.disable_skill(fresh_db, "bravo")
result = await service.list_disabled_skills(fresh_db)
assert result == ["alpha", "bravo", "charlie"]
async def test_list_disabled_skills_empty_when_none_disabled(
self, service: SkillService, fresh_db: Path
):
result = await service.list_disabled_skills(fresh_db)
assert result == []
async def test_disable_skill_invalid_name_raises(self, service: SkillService, fresh_db: Path):
with pytest.raises(ValueError, match="Invalid skill name"):
await service.disable_skill(fresh_db, "Invalid Name With Spaces")
async def test_disable_skill_uppercase_name_normalizes(
self, service: SkillService, fresh_db: Path
):
"""Uppercase names are normalized to lowercase (regex requires lowercase)."""
# 'TEST_SKILL' normalizes to 'test_skill' which matches the regex.
await service.disable_skill(fresh_db, "TEST_SKILL")
assert await service.is_skill_disabled(fresh_db, "test_skill") is True
# ---------------------------------------------------------------------------
# import_skill
# ---------------------------------------------------------------------------
class TestImportSkill:
async def test_import_skill_writes_file_and_loads(
self,
service: SkillService,
skills_dir: str,
skill_registry: SkillRegistry,
):
result = await service.import_skill(
_VALID_SKILL_YAML,
skills_dir,
skill_registry=skill_registry,
)
assert result["name"] == "test_skill"
assert Path(result["path"]).is_file()
# Skill should be registered in the registry.
assert skill_registry.has_skill("test_skill")
async def test_import_skill_without_registry_writes_file_only(
self,
service: SkillService,
skills_dir: str,
):
result = await service.import_skill(_VALID_SKILL_YAML, skills_dir)
assert result["name"] == "test_skill"
assert Path(result["path"]).is_file()
async def test_import_skill_invalid_yaml_raises(
self,
service: SkillService,
skills_dir: str,
):
with pytest.raises(ValueError, match="Invalid YAML"):
await service.import_skill("not: valid: yaml: [", skills_dir)
async def test_import_skill_non_mapping_yaml_raises(
self,
service: SkillService,
skills_dir: str,
):
with pytest.raises(ValueError, match="mapping"):
await service.import_skill("- just\n- a\n- list\n", skills_dir)
async def test_import_skill_missing_name_field_raises(
self,
service: SkillService,
skills_dir: str,
):
yaml_without_name = "agent_type: simple_generation\n"
with pytest.raises(ValueError, match="name"):
await service.import_skill(yaml_without_name, skills_dir)
async def test_import_skill_invalid_name_raises(
self,
service: SkillService,
skills_dir: str,
):
"""YAML with a name that fails the regex raises ValueError."""
bad_yaml = 'name: "Bad Name With Spaces"\nagent_type: test\n'
with pytest.raises(ValueError, match="Invalid skill name"):
await service.import_skill(bad_yaml, skills_dir)
async def test_import_skill_path_traversal_blocked(
self,
service: SkillService,
skills_dir: str,
):
"""A YAML with a name containing path separators is rejected by the regex."""
# The regex `^[a-z0-9][a-z0-9_-]{0,63}$` rejects '/' and '..',
# so path traversal via the name field is impossible. We verify
# this by attempting to import a YAML with a traversal name.
traversal_yaml = 'name: "../etc/passwd"\nagent_type: test\n'
with pytest.raises(ValueError, match="Invalid skill name"):
await service.import_skill(traversal_yaml, skills_dir)
# ---------------------------------------------------------------------------
# reload_skill
# ---------------------------------------------------------------------------
class TestReloadSkill:
async def test_reload_skill_unregisters_and_reloads(
self,
service: SkillService,
skills_dir: str,
skill_registry: SkillRegistry,
):
# First import the skill.
await service.import_skill(
_VALID_SKILL_YAML,
skills_dir,
skill_registry=skill_registry,
)
assert skill_registry.has_skill("test_skill")
# Now reload it.
result = await service.reload_skill("test_skill", skill_registry, skills_dir)
assert result["name"] == "test_skill"
assert result["status"] == "reloaded"
assert skill_registry.has_skill("test_skill")
async def test_reload_skill_missing_yaml_raises(
self,
service: SkillService,
skills_dir: str,
skill_registry: SkillRegistry,
):
with pytest.raises(ValueError, match="not found"):
await service.reload_skill("nonexistent", skill_registry, skills_dir)
async def test_reload_skill_invalid_name_raises(
self,
service: SkillService,
skills_dir: str,
skill_registry: SkillRegistry,
):
with pytest.raises(ValueError, match="Invalid skill name"):
await service.reload_skill("Bad Name", skill_registry, skills_dir)
# ---------------------------------------------------------------------------
# update_skill_config
# ---------------------------------------------------------------------------
class TestUpdateSkillConfig:
async def test_update_skill_config_updates_yaml_and_reloads(
self,
service: SkillService,
skills_dir: str,
skill_registry: SkillRegistry,
):
# First import the skill.
await service.import_skill(
_VALID_SKILL_YAML,
skills_dir,
skill_registry=skill_registry,
)
# Patch the description.
result = await service.update_skill_config(
"test_skill",
{"description": "Updated description"},
skills_dir,
skill_registry,
)
assert result["name"] == "test_skill"
assert result["status"] == "updated"
# Verify the YAML file was updated.
import yaml
with open(result["path"], encoding="utf-8") as f:
data = yaml.safe_load(f)
assert data["description"] == "Updated description"
# Original fields should be preserved.
assert data["name"] == "test_skill"
assert data["agent_type"] == "simple_generation"
# Skill should still be registered.
assert skill_registry.has_skill("test_skill")
async def test_update_skill_config_missing_yaml_raises(
self,
service: SkillService,
skills_dir: str,
skill_registry: SkillRegistry,
):
with pytest.raises(ValueError, match="not found"):
await service.update_skill_config(
"nonexistent",
{"description": "x"},
skills_dir,
skill_registry,
)
async def test_update_skill_config_preserves_name(
self,
service: SkillService,
skills_dir: str,
skill_registry: SkillRegistry,
):
"""Patching the 'name' field is ignored — name is preserved."""
await service.import_skill(
_VALID_SKILL_YAML,
skills_dir,
skill_registry=skill_registry,
)
result = await service.update_skill_config(
"test_skill",
{"name": "should_be_ignored", "description": "new"},
skills_dir,
skill_registry,
)
import yaml
with open(result["path"], encoding="utf-8") as f:
data = yaml.safe_load(f)
# Name should be preserved (not renamed).
assert data["name"] == "test_skill"
assert data["description"] == "new"
# ---------------------------------------------------------------------------
# _validate_skill_name helper
# ---------------------------------------------------------------------------
class TestValidateSkillName:
def test_valid_name_returns_normalized(self):
assert _validate_skill_name("test_skill") == "test_skill"
assert _validate_skill_name("TestSkill") == "testskill"
assert _validate_skill_name(" spaced ") == "spaced"
assert _validate_skill_name("a-b_c-1") == "a-b_c-1"
def test_invalid_name_raises(self):
with pytest.raises(ValueError):
_validate_skill_name("Has Spaces")
with pytest.raises(ValueError):
_validate_skill_name("")
with pytest.raises(ValueError):
_validate_skill_name("-leading-dash")
with pytest.raises(ValueError):
_validate_skill_name("../traversal")
with pytest.raises(ValueError):
_validate_skill_name("has.dot")
with pytest.raises(ValueError):
_validate_skill_name("has/slash")
def test_non_string_raises(self):
with pytest.raises(ValueError):
_validate_skill_name(None) # type: ignore[arg-type]
with pytest.raises(ValueError):
_validate_skill_name(123) # type: ignore[arg-type]

View File

@ -118,9 +118,9 @@ async def _list_index_names(db: aiosqlite.Connection, table: str) -> set[str]:
class TestSchemaVersion: class TestSchemaVersion:
def test_schema_version_is_v3(self): def test_schema_version_is_v4(self):
"""The current schema version is 3 (V3 adds department-scoped admin tables).""" """The current schema version is 4 (V4 adds skill_states table)."""
assert _SCHEMA_VERSION == 3 assert _SCHEMA_VERSION == 4
def test_sqlalchemy_model_table_name(self): def test_sqlalchemy_model_table_name(self):
assert AuthSessionModel.__tablename__ == "auth_sessions" assert AuthSessionModel.__tablename__ == "auth_sessions"

View File

@ -2,13 +2,12 @@
from __future__ import annotations from __future__ import annotations
from unittest.mock import AsyncMock
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from agentkit.llm.gateway import LLMGateway from agentkit.llm.gateway import LLMGateway
from agentkit.server.app import create_app from agentkit.server.app import create_app
from agentkit.server.config import ServerConfig
from agentkit.skills.base import Skill, SkillConfig from agentkit.skills.base import Skill, SkillConfig
from agentkit.skills.registry import SkillRegistry from agentkit.skills.registry import SkillRegistry
from agentkit.tools.registry import ToolRegistry from agentkit.tools.registry import ToolRegistry
@ -202,7 +201,31 @@ class TestListCapabilities:
class TestReloadSkill: class TestReloadSkill:
def test_reload_skill(self, client, skill_registry): def test_reload_skill(self, client, skill_registry, tmp_path):
# Write a valid YAML file for "reload_skill" to a temp skills dir.
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
yaml_content = (
"name: reload_skill\n"
'agent_type: "test_type"\n'
'version: "1.0.0"\n'
'description: "Skill for reload testing"\n'
"task_mode: llm_generate\n"
"execution_mode: direct\n"
"max_steps: 1\n"
"intent:\n"
' keywords: ["reload"]\n'
' description: "reload test"\n'
"prompt:\n"
' identity: "reload tester"\n'
' instructions: "handle reload"\n'
"tools: []\n"
)
(skills_dir / "reload_skill.yaml").write_text(yaml_content, encoding="utf-8")
# Point the app at the temp skills dir and register the skill so the
# 404 guard in the route passes.
client.app.state.server_config = ServerConfig(skill_paths=[str(skills_dir)])
_register_skill(skill_registry, "reload_skill") _register_skill(skill_registry, "reload_skill")
response = client.post("/api/v1/skill-management/skills/reload_skill/reload") response = client.post("/api/v1/skill-management/skills/reload_skill/reload")