feat(admin): U6 — Skill & KB management endpoints + department binding
SkillService: enable/disable (persisted in skill_states table, schema v4), import from YAML (with path traversal + name validation), reload from file, update config. GET /skills now filters disabled skills. KbService: list/upload/delete documents with department_id binding. Added department_id field to KnowledgeSource + UploadedDocument. Department visibility: (bound to user depts) ∪ (global = None). 10 new admin endpoints: skill enable/disable/import/reload/update, KB documents CRUD, source sync/rebuild. All guarded by _require_admin. Implemented reload stub in skill_management.py (was no-op). 54 new tests (26 unit + 28 integration). Fixed 4 pre-existing lint errors. 357 admin tests pass, no regressions.
This commit is contained in:
parent
980919fc95
commit
fd7f6816b8
|
|
@ -0,0 +1,246 @@
|
|||
"""KbService — admin-driven KB document/source management (U6).
|
||||
|
||||
This module wraps the existing in-memory :class:`KnowledgeSourceStore`
|
||||
from :mod:`agentkit.server.routes.kb_management` so admin routes can
|
||||
manage documents and sources through a service layer (mirroring
|
||||
:class:`DepartmentService` and :class:`SkillService`).
|
||||
|
||||
Department filtering
|
||||
--------------------
|
||||
When listing documents, the service filters by department visibility:
|
||||
- documents whose ``department_id`` is in the caller's department set, OR
|
||||
- documents whose ``department_id`` is ``None`` (global documents).
|
||||
|
||||
Admin callers (``department_ids=None``) bypass filtering and see all
|
||||
documents.
|
||||
|
||||
The service is a module-level singleton (see :func:`get_kb_service`)
|
||||
so tests can inject a custom instance via :func:`set_kb_service`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from agentkit.server.routes.kb_management import (
|
||||
KnowledgeSource,
|
||||
KnowledgeSourceStore,
|
||||
UploadedDocument,
|
||||
get_source_store,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
"""Return current UTC time as ISO 8601 string."""
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
class KbService:
|
||||
"""Wraps :class:`KnowledgeSourceStore` with admin-friendly operations.
|
||||
|
||||
The service does NOT own the store — it delegates to the
|
||||
module-level singleton from :mod:`kb_management`. This keeps a
|
||||
single source of truth for KB state (the existing routes still
|
||||
work) while giving admin routes a clean service boundary.
|
||||
|
||||
Department filtering is applied at the service layer so admin
|
||||
routes can pass ``department_ids`` directly from the
|
||||
:class:`DepartmentContext` dependency.
|
||||
"""
|
||||
|
||||
def __init__(self, store: KnowledgeSourceStore | None = None) -> None:
|
||||
self._store = store
|
||||
|
||||
def _resolve_store(self) -> KnowledgeSourceStore:
|
||||
"""Return the injected store, or the module singleton."""
|
||||
if self._store is not None:
|
||||
return self._store
|
||||
return get_source_store()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Document operations
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def list_documents(
|
||||
self,
|
||||
department_ids: list[str] | None = None,
|
||||
source_id: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""List documents, optionally filtered by source and department.
|
||||
|
||||
Args:
|
||||
department_ids: Caller's department ids. ``None`` means admin
|
||||
(no filtering — return all documents). An empty list means
|
||||
a user with no departments (return only global documents).
|
||||
source_id: Optional source id filter.
|
||||
|
||||
Returns:
|
||||
List of document dicts (JSON-safe).
|
||||
"""
|
||||
store = self._resolve_store()
|
||||
documents = store.list_documents(source_id=source_id)
|
||||
|
||||
# Admin bypass: department_ids=None means "see everything".
|
||||
if department_ids is None:
|
||||
return [self._doc_to_dict(d) for d in documents]
|
||||
|
||||
# Non-admin: visible = (docs in user's departments) ∪ (global docs).
|
||||
dept_set = set(department_ids)
|
||||
visible = [d for d in documents if d.department_id is None or d.department_id in dept_set]
|
||||
return [self._doc_to_dict(d) for d in visible]
|
||||
|
||||
def upload_document(
|
||||
self,
|
||||
filename: str,
|
||||
content: bytes,
|
||||
source_id: str = "",
|
||||
department_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Upload a document with optional department binding.
|
||||
|
||||
Args:
|
||||
filename: Original filename.
|
||||
content: File content bytes (used to estimate chunk count).
|
||||
source_id: Optional source id. Defaults to ``"local"``.
|
||||
department_id: Optional department id to bind the document to.
|
||||
``None`` means the document is global (visible to all).
|
||||
|
||||
Returns:
|
||||
The uploaded document dict.
|
||||
"""
|
||||
store = self._resolve_store()
|
||||
effective_source_id = source_id or "local"
|
||||
|
||||
# Ensure a local source exists for default uploads.
|
||||
if effective_source_id == "local":
|
||||
local_sources = [s for s in store.list_sources() if s.type == "local"]
|
||||
if not local_sources:
|
||||
store.add_source("本地文档", "local", {})
|
||||
|
||||
# Estimate chunks from content length (rough approximation,
|
||||
# mirrors the existing /kb-management/documents/upload route).
|
||||
chunks = max(1, len(content) // 500)
|
||||
|
||||
doc = UploadedDocument(
|
||||
document_id=str(uuid.uuid4()),
|
||||
filename=filename,
|
||||
source_id=effective_source_id,
|
||||
chunks=chunks,
|
||||
status="indexed",
|
||||
department_id=department_id,
|
||||
)
|
||||
store.add_document(doc)
|
||||
return self._doc_to_dict(doc)
|
||||
|
||||
def delete_document(self, document_id: str) -> bool:
|
||||
"""Delete a document by id. Returns ``True`` if deleted."""
|
||||
store = self._resolve_store()
|
||||
return store.delete_document(document_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Source operations
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def list_sources(self) -> list[dict[str, Any]]:
|
||||
"""List all KB sources."""
|
||||
store = self._resolve_store()
|
||||
return [self._source_to_dict(s) for s in store.list_sources()]
|
||||
|
||||
def get_source(self, source_id: str) -> dict[str, Any] | None:
|
||||
"""Return a single source by id, or ``None`` if not found."""
|
||||
store = self._resolve_store()
|
||||
source = store.get_source(source_id)
|
||||
return self._source_to_dict(source) if source else None
|
||||
|
||||
def sync_source(self, source_id: str) -> dict[str, Any]:
|
||||
"""Trigger a source sync (stub — marks status as ``syncing``).
|
||||
|
||||
Raises:
|
||||
ValueError: If the source does not exist.
|
||||
"""
|
||||
store = self._resolve_store()
|
||||
source = store.get_source(source_id)
|
||||
if source is None:
|
||||
raise ValueError(f"KB source {source_id!r} not found")
|
||||
# Stub: mark status as syncing. A real implementation would
|
||||
# kick off an async sync job and return a job id.
|
||||
source.status = "syncing"
|
||||
source.last_synced = _now_iso()
|
||||
return {
|
||||
"source_id": source_id,
|
||||
"status": "syncing",
|
||||
"message": "Sync started",
|
||||
}
|
||||
|
||||
def rebuild_index(self, source_id: str) -> dict[str, Any]:
|
||||
"""Rebuild the index for a source (stub — marks status).
|
||||
|
||||
Raises:
|
||||
ValueError: If the source does not exist.
|
||||
"""
|
||||
store = self._resolve_store()
|
||||
source = store.get_source(source_id)
|
||||
if source is None:
|
||||
raise ValueError(f"KB source {source_id!r} not found")
|
||||
# Stub: mark status as rebuilding.
|
||||
source.status = "rebuilding"
|
||||
return {
|
||||
"source_id": source_id,
|
||||
"status": "rebuilding",
|
||||
"message": "Index rebuild started",
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Serialization helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _doc_to_dict(self, doc: UploadedDocument) -> dict[str, Any]:
|
||||
"""Convert an :class:`UploadedDocument` to a JSON-safe dict."""
|
||||
return {
|
||||
"document_id": doc.document_id,
|
||||
"filename": doc.filename,
|
||||
"source_id": doc.source_id,
|
||||
"chunks": doc.chunks,
|
||||
"status": doc.status,
|
||||
"created_at": doc.created_at,
|
||||
"department_id": doc.department_id,
|
||||
}
|
||||
|
||||
def _source_to_dict(self, source: KnowledgeSource) -> dict[str, Any]:
|
||||
"""Convert a :class:`KnowledgeSource` to a JSON-safe dict."""
|
||||
return {
|
||||
"id": source.id,
|
||||
"name": source.name,
|
||||
"type": source.type,
|
||||
"status": source.status,
|
||||
"document_count": source.document_count,
|
||||
"last_synced": source.last_synced,
|
||||
"department_id": source.department_id,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level singleton (overridable in tests via set_kb_service)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_kb_service: KbService | None = None
|
||||
|
||||
|
||||
def get_kb_service() -> KbService:
|
||||
"""Return the process-wide :class:`KbService` (lazy singleton)."""
|
||||
global _kb_service
|
||||
if _kb_service is None:
|
||||
_kb_service = KbService()
|
||||
return _kb_service
|
||||
|
||||
|
||||
def set_kb_service(service: KbService | None) -> None:
|
||||
"""Inject a custom :class:`KbService` (used by tests)."""
|
||||
global _kb_service
|
||||
_kb_service = service
|
||||
|
|
@ -0,0 +1,428 @@
|
|||
"""SkillService — admin-driven skill enable/disable + import/reload (U6).
|
||||
|
||||
This module is the single owner of the ``skill_states`` table (admin
|
||||
disable markers) and the YAML import/reload pipeline. Web UI routes
|
||||
(``/api/v1/admin/skills/*``) and the existing skill-management route
|
||||
both call into :class:`SkillService` rather than touching the table
|
||||
or filesystem directly, keeping validation rules (skill-name regex,
|
||||
path-traversal guard, YAML validation) in one place.
|
||||
|
||||
The service is a module-level singleton (see :func:`get_skill_service`)
|
||||
so tests can inject a custom instance via :func:`set_skill_service`.
|
||||
|
||||
Design notes
|
||||
------------
|
||||
- Each method opens its own short-lived :class:`aiosqlite.Connection`
|
||||
(mirrors :class:`DepartmentService`).
|
||||
- ``import_skill`` reuses the validation regex from
|
||||
:mod:`agentkit.server.routes.skills` but inlines a copy to avoid a
|
||||
circular import (routes.skills → admin.skill_service → routes.skills).
|
||||
- ``reload_skill`` and ``update_skill_config`` accept a
|
||||
:class:`SkillRegistry` instance so the caller can pass the live
|
||||
registry from ``app.state``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import aiosqlite
|
||||
import yaml
|
||||
|
||||
from agentkit.server.auth.models import skill_state_row_to_dict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agentkit.skills.registry import SkillRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants & helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Strict skill name validation: lowercase alphanumeric, hyphens, underscores.
|
||||
# Inlined from routes/skills.py to avoid a circular import.
|
||||
_SKILL_NAME_RE = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
"""Return current UTC time as ISO 8601 string."""
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _validate_skill_name(name: str) -> str:
|
||||
"""Validate and normalize a skill name. Raises ``ValueError`` on invalid input.
|
||||
|
||||
Mirrors the regex in :mod:`agentkit.server.routes.skills` but raises
|
||||
``ValueError`` (not ``HTTPException``) so the service is HTTP-agnostic.
|
||||
"""
|
||||
if not isinstance(name, str):
|
||||
raise ValueError("Skill name must be a string")
|
||||
normalized = name.strip().lower()
|
||||
if not _SKILL_NAME_RE.match(normalized):
|
||||
raise ValueError(
|
||||
f"Invalid skill name {name!r}: must contain only lowercase letters, "
|
||||
f"digits, hyphens, and underscores (1-64 chars)"
|
||||
)
|
||||
return normalized
|
||||
|
||||
|
||||
def _validate_yaml_content(content: str) -> dict[str, Any]:
|
||||
"""Validate YAML content. Returns parsed dict. Raises ``ValueError`` on invalid input.
|
||||
|
||||
Mirrors the validation in :mod:`agentkit.server.routes.skills` but
|
||||
raises ``ValueError`` (not ``HTTPException``).
|
||||
"""
|
||||
try:
|
||||
data = yaml.safe_load(content)
|
||||
except yaml.YAMLError as exc:
|
||||
raise ValueError(f"Invalid YAML content: {exc}") from exc
|
||||
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError("Skill YAML must be a mapping/dict")
|
||||
|
||||
if "name" not in data:
|
||||
raise ValueError("Skill YAML must contain a 'name' field")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def _is_within_directory(file_path: str, directory: str) -> bool:
|
||||
"""Return ``True`` if ``file_path`` resolves inside ``directory``.
|
||||
|
||||
Uses :func:`os.path.realpath` to canonicalize both paths before
|
||||
comparison, so symlink traversal and ``..`` segments are detected.
|
||||
"""
|
||||
real_file = os.path.realpath(file_path)
|
||||
real_dir = os.path.realpath(directory)
|
||||
return real_file.startswith(real_dir + os.sep) or real_file == real_dir
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Service
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SkillService:
|
||||
"""Admin-driven skill enable/disable + import/reload operations.
|
||||
|
||||
All DB-touching methods are async and take ``db_path: Path`` as the
|
||||
first argument. YAML-touching methods take ``skills_dir: str`` and
|
||||
optionally a :class:`SkillRegistry` instance for live reload.
|
||||
"""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Enable / disable (DB state)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def disable_skill(
|
||||
self,
|
||||
db_path: Path,
|
||||
skill_name: str,
|
||||
disabled_by: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Mark a skill as disabled in the DB.
|
||||
|
||||
Idempotent: if the skill is already disabled, the row is updated
|
||||
(refreshing ``disabled_at`` and ``disabled_by``).
|
||||
|
||||
Args:
|
||||
db_path: Path to the auth SQLite DB.
|
||||
skill_name: Skill name to disable (validated against the regex).
|
||||
disabled_by: Optional admin user id (audit trail).
|
||||
|
||||
Returns:
|
||||
The disabled skill state dict.
|
||||
"""
|
||||
normalized = _validate_skill_name(skill_name)
|
||||
now = _now_iso()
|
||||
async with aiosqlite.connect(str(db_path)) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
await db.execute(
|
||||
"INSERT INTO skill_states (skill_name, is_disabled, disabled_at, disabled_by) "
|
||||
"VALUES (?, 1, ?, ?) "
|
||||
"ON CONFLICT(skill_name) DO UPDATE SET "
|
||||
" is_disabled = excluded.is_disabled, "
|
||||
" disabled_at = excluded.disabled_at, "
|
||||
" disabled_by = excluded.disabled_by",
|
||||
(normalized, now, disabled_by),
|
||||
)
|
||||
await db.commit()
|
||||
cursor = await db.execute(
|
||||
"SELECT * FROM skill_states WHERE skill_name = ?",
|
||||
(normalized,),
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
assert row is not None # we just upserted it
|
||||
return skill_state_row_to_dict(row)
|
||||
|
||||
async def enable_skill(
|
||||
self,
|
||||
db_path: Path,
|
||||
skill_name: str,
|
||||
) -> bool:
|
||||
"""Remove the disabled mark for a skill.
|
||||
|
||||
Returns:
|
||||
``True`` if a row was deleted (skill was previously disabled),
|
||||
``False`` if the skill was not disabled.
|
||||
"""
|
||||
normalized = _validate_skill_name(skill_name)
|
||||
async with aiosqlite.connect(str(db_path)) as db:
|
||||
cursor = await db.execute(
|
||||
"DELETE FROM skill_states WHERE skill_name = ?",
|
||||
(normalized,),
|
||||
)
|
||||
await db.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
async def is_skill_disabled(
|
||||
self,
|
||||
db_path: Path,
|
||||
skill_name: str,
|
||||
) -> bool:
|
||||
"""Return ``True`` if the skill is currently disabled."""
|
||||
normalized = _validate_skill_name(skill_name)
|
||||
async with aiosqlite.connect(str(db_path)) as db:
|
||||
cursor = await db.execute(
|
||||
"SELECT is_disabled FROM skill_states WHERE skill_name = ?",
|
||||
(normalized,),
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
if row is None:
|
||||
return False
|
||||
return bool(row[0])
|
||||
|
||||
async def list_disabled_skills(
|
||||
self,
|
||||
db_path: Path,
|
||||
) -> list[str]:
|
||||
"""Return the list of skill names currently disabled."""
|
||||
async with aiosqlite.connect(str(db_path)) as db:
|
||||
cursor = await db.execute(
|
||||
"SELECT skill_name FROM skill_states WHERE is_disabled = 1 ORDER BY skill_name ASC"
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
return [row[0] for row in rows]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# YAML import / reload / update
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def import_skill(
|
||||
self,
|
||||
yaml_content: str,
|
||||
skills_dir: str,
|
||||
skill_registry: SkillRegistry | None = None,
|
||||
tool_registry: Any = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Validate YAML, write to file, and load into the registry.
|
||||
|
||||
Args:
|
||||
yaml_content: Raw YAML text for the skill config.
|
||||
skills_dir: Target directory for the YAML file. The file is
|
||||
written as ``{skills_dir}/{skill_name}.yaml``.
|
||||
skill_registry: Optional :class:`SkillRegistry` to register
|
||||
the loaded skill into. If ``None``, the skill is only
|
||||
written to disk.
|
||||
tool_registry: Optional :class:`ToolRegistry` for tool binding.
|
||||
|
||||
Returns:
|
||||
Dict with ``name`` and ``path`` of the imported skill.
|
||||
|
||||
Raises:
|
||||
ValueError: If the YAML is invalid, the skill name is invalid,
|
||||
or the resolved path escapes ``skills_dir``.
|
||||
"""
|
||||
data = _validate_yaml_content(yaml_content)
|
||||
raw_name = data["name"]
|
||||
skill_name = _validate_skill_name(raw_name)
|
||||
|
||||
os.makedirs(skills_dir, exist_ok=True)
|
||||
file_path = os.path.join(skills_dir, f"{skill_name}.yaml")
|
||||
|
||||
# Path traversal protection: the resolved path must stay inside
|
||||
# skills_dir. (skill_name is regex-validated so this should always
|
||||
# pass, but we double-check defensively.)
|
||||
if not _is_within_directory(file_path, skills_dir):
|
||||
raise ValueError(f"Resolved skill path escapes skills_dir: {file_path}")
|
||||
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(yaml_content)
|
||||
|
||||
if skill_registry is not None:
|
||||
try:
|
||||
from agentkit.skills.loader import SkillLoader
|
||||
|
||||
loader = SkillLoader(
|
||||
skill_registry=skill_registry,
|
||||
tool_registry=tool_registry,
|
||||
)
|
||||
loader.load_from_file(file_path)
|
||||
except Exception:
|
||||
# Remove the invalid YAML file and re-raise as ValueError
|
||||
# so the route layer maps it to 400.
|
||||
try:
|
||||
os.remove(file_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise ValueError("Skill YAML written but registration failed")
|
||||
|
||||
return {"name": skill_name, "path": file_path}
|
||||
|
||||
async def reload_skill(
|
||||
self,
|
||||
skill_name: str,
|
||||
skill_registry: SkillRegistry,
|
||||
skills_dir: str,
|
||||
tool_registry: Any = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Unregister and reload a skill from its YAML file.
|
||||
|
||||
Args:
|
||||
skill_name: Skill name to reload (validated against the regex).
|
||||
skill_registry: The live :class:`SkillRegistry` instance.
|
||||
skills_dir: Directory containing the YAML file.
|
||||
tool_registry: Optional :class:`ToolRegistry` for tool binding.
|
||||
|
||||
Returns:
|
||||
Dict with ``name``, ``path``, and ``status`` of the reload.
|
||||
|
||||
Raises:
|
||||
ValueError: If the skill name is invalid or the YAML file
|
||||
cannot be found.
|
||||
"""
|
||||
normalized = _validate_skill_name(skill_name)
|
||||
file_path = os.path.join(skills_dir, f"{normalized}.yaml")
|
||||
if not os.path.isfile(file_path):
|
||||
raise ValueError(f"Skill {normalized!r} not found: no YAML file at {file_path}")
|
||||
|
||||
if not _is_within_directory(file_path, skills_dir):
|
||||
raise ValueError(f"Resolved skill path escapes skills_dir: {file_path}")
|
||||
|
||||
# Unregister the existing skill (if any) before reloading so the
|
||||
# registry picks up the new config cleanly.
|
||||
try:
|
||||
skill_registry.unregister(normalized)
|
||||
except Exception: # noqa: BLE001 — unregister is best-effort
|
||||
logger.debug("Skill %r was not registered before reload", normalized)
|
||||
|
||||
from agentkit.skills.loader import SkillLoader
|
||||
|
||||
loader = SkillLoader(
|
||||
skill_registry=skill_registry,
|
||||
tool_registry=tool_registry,
|
||||
)
|
||||
loader.load_from_file(file_path)
|
||||
|
||||
return {
|
||||
"name": normalized,
|
||||
"path": file_path,
|
||||
"status": "reloaded",
|
||||
}
|
||||
|
||||
async def update_skill_config(
|
||||
self,
|
||||
skill_name: str,
|
||||
config_patch: dict[str, Any],
|
||||
skills_dir: str,
|
||||
skill_registry: SkillRegistry,
|
||||
tool_registry: Any = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Update a skill's YAML config in-place and reload it.
|
||||
|
||||
Performs a shallow merge: top-level keys in ``config_patch``
|
||||
overwrite the existing YAML values. Nested dicts are replaced
|
||||
wholesale (not deep-merged) to keep the semantics predictable.
|
||||
|
||||
Args:
|
||||
skill_name: Skill name to update (validated against the regex).
|
||||
config_patch: Dict of top-level YAML keys to update.
|
||||
skills_dir: Directory containing the YAML file.
|
||||
skill_registry: The live :class:`SkillRegistry` instance.
|
||||
tool_registry: Optional :class:`ToolRegistry` for tool binding.
|
||||
|
||||
Returns:
|
||||
Dict with ``name``, ``path``, and ``status`` of the update.
|
||||
|
||||
Raises:
|
||||
ValueError: If the skill name is invalid, the YAML file
|
||||
cannot be found, or the patched config is invalid.
|
||||
"""
|
||||
normalized = _validate_skill_name(skill_name)
|
||||
file_path = os.path.join(skills_dir, f"{normalized}.yaml")
|
||||
if not os.path.isfile(file_path):
|
||||
raise ValueError(f"Skill {normalized!r} not found: no YAML file at {file_path}")
|
||||
|
||||
if not _is_within_directory(file_path, skills_dir):
|
||||
raise ValueError(f"Resolved skill path escapes skills_dir: {file_path}")
|
||||
|
||||
# Read existing YAML, apply patch, validate, write back.
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
existing = yaml.safe_load(f) or {}
|
||||
|
||||
if not isinstance(existing, dict):
|
||||
raise ValueError(f"Existing skill YAML at {file_path} is not a mapping")
|
||||
|
||||
# Shallow merge: top-level keys in config_patch overwrite.
|
||||
# The 'name' field is preserved (cannot rename via patch).
|
||||
patched = {**existing, **config_patch, "name": existing.get("name", normalized)}
|
||||
|
||||
# Validate the patched config by round-tripping through YAML.
|
||||
try:
|
||||
patched_yaml = yaml.safe_dump(patched, default_flow_style=False, allow_unicode=True)
|
||||
except yaml.YAMLError as exc:
|
||||
raise ValueError(f"Failed to serialize patched config: {exc}") from exc
|
||||
|
||||
_validate_yaml_content(patched_yaml)
|
||||
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(patched_yaml)
|
||||
|
||||
# Reload the skill into the registry.
|
||||
try:
|
||||
skill_registry.unregister(normalized)
|
||||
except Exception: # noqa: BLE001 — unregister is best-effort
|
||||
logger.debug("Skill %r was not registered before update", normalized)
|
||||
|
||||
from agentkit.skills.loader import SkillLoader
|
||||
|
||||
loader = SkillLoader(
|
||||
skill_registry=skill_registry,
|
||||
tool_registry=tool_registry,
|
||||
)
|
||||
loader.load_from_file(file_path)
|
||||
|
||||
return {
|
||||
"name": normalized,
|
||||
"path": file_path,
|
||||
"status": "updated",
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level singleton (overridable in tests via set_skill_service)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_skill_service: SkillService | None = None
|
||||
|
||||
|
||||
def get_skill_service() -> SkillService:
|
||||
"""Return the process-wide :class:`SkillService` (lazy singleton)."""
|
||||
global _skill_service
|
||||
if _skill_service is None:
|
||||
_skill_service = SkillService()
|
||||
return _skill_service
|
||||
|
||||
|
||||
def set_skill_service(service: SkillService | None) -> None:
|
||||
"""Inject a custom :class:`SkillService` (used by tests)."""
|
||||
global _skill_service
|
||||
_skill_service = service
|
||||
|
|
@ -351,6 +351,24 @@ class DepartmentQuotaModel(Base):
|
|||
updated_at: Mapped[str] = mapped_column(String(64), nullable=False, default=_now_iso)
|
||||
|
||||
|
||||
class SkillStateModel(Base):
|
||||
"""Skill enable/disable state (V4 — Admin Console U6).
|
||||
|
||||
Each row records whether a named skill has been disabled by an
|
||||
admin via the ``/admin/skills/{name}/disable`` endpoint. Skills
|
||||
with no row here are considered enabled (the default). ``skill_name``
|
||||
references the skill registry identifier (not a DB FK — skills are
|
||||
defined in YAML configs).
|
||||
"""
|
||||
|
||||
__tablename__ = "skill_states"
|
||||
|
||||
skill_name: Mapped[str] = mapped_column(String(128), primary_key=True)
|
||||
is_disabled: Mapped[bool] = mapped_column(default=True, nullable=False)
|
||||
disabled_at: Mapped[str] = mapped_column(String(64), nullable=False, default=_now_iso)
|
||||
disabled_by: Mapped[str | None] = mapped_column(String(36), nullable=True)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema DDL (kept in sync with the models above for aiosqlite bootstrap)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -565,6 +583,17 @@ CREATE TABLE IF NOT EXISTS department_quotas (
|
|||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_department_quotas_department_id
|
||||
ON department_quotas(department_id);
|
||||
|
||||
-- V4: skill_states records admin-disabled skills (U6 — Admin Console).
|
||||
-- A skill with no row here is considered enabled (the default). Only
|
||||
-- disabled skills have a row, with is_disabled=1. disabled_by records
|
||||
-- the admin user id who disabled the skill (audit trail).
|
||||
CREATE TABLE IF NOT EXISTS skill_states (
|
||||
skill_name TEXT PRIMARY KEY,
|
||||
is_disabled INTEGER NOT NULL DEFAULT 1,
|
||||
disabled_at TEXT NOT NULL,
|
||||
disabled_by TEXT
|
||||
);
|
||||
"""
|
||||
|
||||
|
||||
|
|
@ -580,7 +609,10 @@ CREATE INDEX IF NOT EXISTS idx_department_quotas_department_id
|
|||
# V3 (2026-06-21, Admin Console): added departments, user_departments,
|
||||
# department_skill_bindings, department_kb_bindings, department_quotas.
|
||||
# No backfill needed — all new tables are additive.
|
||||
_SCHEMA_VERSION = 3
|
||||
#
|
||||
# V4 (2026-06-21, Admin Console U6): added skill_states table for
|
||||
# admin-driven skill enable/disable. No backfill needed — additive.
|
||||
_SCHEMA_VERSION = 4
|
||||
|
||||
_META_SCHEMA_VERSION_KEY = "schema_version"
|
||||
|
||||
|
|
@ -803,3 +835,18 @@ def user_department_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> di
|
|||
"department_id": row["department_id"],
|
||||
"created_at": row["created_at"],
|
||||
}
|
||||
|
||||
|
||||
def skill_state_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> dict[str, Any]:
|
||||
"""Convert a ``skill_states`` row into a JSON-safe dict.
|
||||
|
||||
The ``is_disabled`` field is normalized to a Python ``bool`` (the DB
|
||||
stores 0/1). ``disabled_by`` is ``None`` when the disabling admin
|
||||
is not recorded (e.g. legacy rows or system-initiated disables).
|
||||
"""
|
||||
return {
|
||||
"skill_name": row["skill_name"],
|
||||
"is_disabled": bool(row["is_disabled"]),
|
||||
"disabled_at": row["disabled_at"],
|
||||
"disabled_by": row["disabled_by"],
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,11 +23,13 @@ from fastapi import APIRouter, Depends, HTTPException, Request
|
|||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from agentkit.server.admin.department_service import get_department_service
|
||||
from agentkit.server.admin.kb_service import get_kb_service
|
||||
from agentkit.server.admin.llm_config_service import (
|
||||
LlmConfigService,
|
||||
get_llm_config_service,
|
||||
)
|
||||
from agentkit.server.admin.quota_service import get_quota_service
|
||||
from agentkit.server.admin.skill_service import get_skill_service
|
||||
from agentkit.server.admin.user_service import get_user_service
|
||||
from agentkit.server.auth.dependencies import require_authenticated
|
||||
from agentkit.server.auth.models import DEFAULT_AUTH_DB_PATH, init_auth_db
|
||||
|
|
@ -854,3 +856,285 @@ async def delete_department_quota(
|
|||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Quota not found")
|
||||
return {"deleted": True}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Skill management endpoints (U6) — enable/disable + import/reload
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_skills_dir(request: Request) -> str:
|
||||
"""Resolve the skills directory from ``app.state.server_config``.
|
||||
|
||||
Mirrors the logic in :func:`agentkit.server.routes.skills._get_skills_dir`
|
||||
so admin endpoints install/reload skills into the same directory as
|
||||
the existing ``/skills/install`` route.
|
||||
"""
|
||||
server_config = getattr(request.app.state, "server_config", None)
|
||||
if server_config and getattr(server_config, "skill_paths", None):
|
||||
first_path = Path(server_config.skill_paths[0])
|
||||
if first_path.is_dir():
|
||||
return str(first_path)
|
||||
# Fallback: configs/skills/ relative to cwd (matches routes.skills).
|
||||
import os
|
||||
|
||||
return os.path.join(os.getcwd(), "configs", "skills")
|
||||
|
||||
|
||||
def _get_skill_registry(request: Request) -> Any:
|
||||
"""Return the live :class:`SkillRegistry` from ``app.state``.
|
||||
|
||||
Raises HTTPException(500) if the registry is missing — admin skill
|
||||
endpoints cannot function without it.
|
||||
"""
|
||||
registry = getattr(request.app.state, "skill_registry", None)
|
||||
if registry is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Skill registry not initialized on app.state",
|
||||
)
|
||||
return registry
|
||||
|
||||
|
||||
class SkillImportRequest(BaseModel):
|
||||
"""Body for ``POST /admin/skills/import``."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
yaml_content: str
|
||||
|
||||
|
||||
class SkillUpdateRequest(BaseModel):
|
||||
"""Body for ``PATCH /admin/skills/{name}``."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
config: dict[str, Any]
|
||||
|
||||
|
||||
@admin_router.post("/skills/{name}/enable")
|
||||
async def enable_skill(
|
||||
name: str,
|
||||
request: Request,
|
||||
admin: dict[str, Any] = Depends(_require_admin),
|
||||
) -> dict[str, Any]:
|
||||
"""Enable a previously-disabled skill.
|
||||
|
||||
Returns 200 ``{enabled: true}`` on success. Idempotent — returns
|
||||
200 even if the skill was not disabled.
|
||||
"""
|
||||
db_path = await _ensure_db(request)
|
||||
svc = get_skill_service()
|
||||
try:
|
||||
enabled = await svc.enable_skill(db_path, name)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
return {"enabled": enabled, "skill_name": name}
|
||||
|
||||
|
||||
@admin_router.post("/skills/{name}/disable")
|
||||
async def disable_skill(
|
||||
name: str,
|
||||
request: Request,
|
||||
admin: dict[str, Any] = Depends(_require_admin),
|
||||
) -> dict[str, Any]:
|
||||
"""Disable a skill (hides it from ``GET /skills``).
|
||||
|
||||
Returns 200 with the disabled skill state dict.
|
||||
"""
|
||||
db_path = await _ensure_db(request)
|
||||
svc = get_skill_service()
|
||||
try:
|
||||
return await svc.disable_skill(db_path, name, disabled_by=admin.get("user_id"))
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@admin_router.patch("/skills/{name}")
|
||||
async def update_skill(
|
||||
name: str,
|
||||
payload: SkillUpdateRequest,
|
||||
request: Request,
|
||||
admin: dict[str, Any] = Depends(_require_admin),
|
||||
) -> dict[str, Any]:
|
||||
"""Update a skill's YAML config in-place and reload it.
|
||||
|
||||
Returns 200 with the updated skill info. Returns 404 if the skill
|
||||
YAML file does not exist, 400 if the patched config is invalid.
|
||||
"""
|
||||
skills_dir = _get_skills_dir(request)
|
||||
registry = _get_skill_registry(request)
|
||||
svc = get_skill_service()
|
||||
try:
|
||||
return await svc.update_skill_config(
|
||||
name,
|
||||
payload.config,
|
||||
skills_dir,
|
||||
registry,
|
||||
)
|
||||
except ValueError as exc:
|
||||
msg = str(exc)
|
||||
if "not found" in msg:
|
||||
raise HTTPException(status_code=404, detail=msg) from exc
|
||||
raise HTTPException(status_code=400, detail=msg) from exc
|
||||
|
||||
|
||||
@admin_router.post("/skills/import")
|
||||
async def import_skill(
|
||||
payload: SkillImportRequest,
|
||||
request: Request,
|
||||
admin: dict[str, Any] = Depends(_require_admin),
|
||||
) -> dict[str, Any]:
|
||||
"""Import a skill from YAML content.
|
||||
|
||||
Writes the YAML to ``{skills_dir}/{name}.yaml`` and registers it
|
||||
in the live :class:`SkillRegistry`. Returns 200 with the imported
|
||||
skill info. Returns 400 if the YAML is invalid or the skill name
|
||||
is invalid.
|
||||
"""
|
||||
skills_dir = _get_skills_dir(request)
|
||||
registry = _get_skill_registry(request)
|
||||
svc = get_skill_service()
|
||||
try:
|
||||
return await svc.import_skill(
|
||||
payload.yaml_content,
|
||||
skills_dir,
|
||||
skill_registry=registry,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@admin_router.post("/skills/{name}/reload")
|
||||
async def reload_skill(
|
||||
name: str,
|
||||
request: Request,
|
||||
admin: dict[str, Any] = Depends(_require_admin),
|
||||
) -> dict[str, Any]:
|
||||
"""Reload a skill from its YAML file.
|
||||
|
||||
Returns 200 with the reloaded skill info. Returns 404 if the YAML
|
||||
file does not exist, 400 if the skill name is invalid.
|
||||
"""
|
||||
skills_dir = _get_skills_dir(request)
|
||||
registry = _get_skill_registry(request)
|
||||
svc = get_skill_service()
|
||||
try:
|
||||
return await svc.reload_skill(name, registry, skills_dir)
|
||||
except ValueError as exc:
|
||||
msg = str(exc)
|
||||
if "not found" in msg:
|
||||
raise HTTPException(status_code=404, detail=msg) from exc
|
||||
raise HTTPException(status_code=400, detail=msg) from exc
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# KB management endpoints (U6) — documents + sources
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class KbDocumentUploadRequest(BaseModel):
|
||||
"""Body for ``POST /admin/kb/documents``."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
filename: str
|
||||
content: str
|
||||
source_id: str = ""
|
||||
department_id: str | None = None
|
||||
|
||||
|
||||
@admin_router.get("/kb/documents")
|
||||
async def list_kb_documents(
|
||||
request: Request,
|
||||
source_id: str | None = None,
|
||||
department_id: str | None = None,
|
||||
admin: dict[str, Any] = Depends(_require_admin),
|
||||
) -> dict[str, Any]:
|
||||
"""List KB documents (admin sees all).
|
||||
|
||||
Query params: ``source_id`` (optional filter), ``department_id``
|
||||
(optional filter — when set, only documents bound to that
|
||||
department or global documents are returned).
|
||||
"""
|
||||
svc = get_kb_service()
|
||||
# Admin sees everything. If department_id is provided as a query
|
||||
# filter, narrow to that department + global docs.
|
||||
if department_id is not None:
|
||||
dept_ids = [department_id]
|
||||
else:
|
||||
dept_ids = None # admin: no filtering
|
||||
documents = svc.list_documents(department_ids=dept_ids, source_id=source_id)
|
||||
return {"documents": documents}
|
||||
|
||||
|
||||
@admin_router.post("/kb/documents", status_code=201)
|
||||
async def upload_kb_document(
|
||||
payload: KbDocumentUploadRequest,
|
||||
request: Request,
|
||||
admin: dict[str, Any] = Depends(_require_admin),
|
||||
) -> dict[str, Any]:
|
||||
"""Upload a KB document with optional department binding.
|
||||
|
||||
Returns 201 with the uploaded document dict.
|
||||
"""
|
||||
svc = get_kb_service()
|
||||
return svc.upload_document(
|
||||
filename=payload.filename,
|
||||
content=payload.content.encode("utf-8"),
|
||||
source_id=payload.source_id,
|
||||
department_id=payload.department_id,
|
||||
)
|
||||
|
||||
|
||||
@admin_router.delete("/kb/documents/{document_id}")
|
||||
async def delete_kb_document(
|
||||
document_id: str,
|
||||
request: Request,
|
||||
admin: dict[str, Any] = Depends(_require_admin),
|
||||
) -> dict[str, Any]:
|
||||
"""Delete a KB document by id.
|
||||
|
||||
Returns 200 ``{deleted: true}`` on success, 404 if not found.
|
||||
"""
|
||||
svc = get_kb_service()
|
||||
deleted = svc.delete_document(document_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
return {"deleted": True}
|
||||
|
||||
|
||||
@admin_router.post("/kb/sources/{source_id}/sync")
|
||||
async def sync_kb_source(
|
||||
source_id: str,
|
||||
request: Request,
|
||||
admin: dict[str, Any] = Depends(_require_admin),
|
||||
) -> dict[str, Any]:
|
||||
"""Trigger a sync for a KB source.
|
||||
|
||||
Returns 200 with the sync status. Returns 404 if the source does
|
||||
not exist.
|
||||
"""
|
||||
svc = get_kb_service()
|
||||
try:
|
||||
return svc.sync_source(source_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=404, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@admin_router.post("/kb/sources/{source_id}/rebuild")
|
||||
async def rebuild_kb_source(
|
||||
source_id: str,
|
||||
request: Request,
|
||||
admin: dict[str, Any] = Depends(_require_admin),
|
||||
) -> dict[str, Any]:
|
||||
"""Rebuild the index for a KB source.
|
||||
|
||||
Returns 200 with the rebuild status. Returns 404 if the source
|
||||
does not exist.
|
||||
"""
|
||||
svc = get_kb_service()
|
||||
try:
|
||||
return svc.rebuild_index(source_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=404, detail=str(exc)) from exc
|
||||
|
|
|
|||
|
|
@ -73,6 +73,7 @@ class KnowledgeSource:
|
|||
status: str = "active"
|
||||
document_count: int = 0
|
||||
last_synced: str | None = None
|
||||
department_id: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -83,6 +84,7 @@ class UploadedDocument:
|
|||
chunks: int
|
||||
status: str
|
||||
created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||
department_id: str | None = None
|
||||
|
||||
|
||||
class KnowledgeSourceStore:
|
||||
|
|
@ -155,6 +157,23 @@ class KnowledgeSourceStore:
|
|||
_source_store = KnowledgeSourceStore()
|
||||
|
||||
|
||||
def get_source_store() -> KnowledgeSourceStore:
|
||||
"""Return the process-wide :class:`KnowledgeSourceStore` singleton.
|
||||
|
||||
Exposed as a function (rather than importing ``_source_store``
|
||||
directly) so tests can swap it out via monkeypatch, and so the
|
||||
:class:`KbService` wrapper can access the store without reaching
|
||||
into a private module attribute.
|
||||
"""
|
||||
return _source_store
|
||||
|
||||
|
||||
def set_source_store(store: KnowledgeSourceStore) -> None:
|
||||
"""Inject a custom :class:`KnowledgeSourceStore` (used by tests)."""
|
||||
global _source_store
|
||||
_source_store = store
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request / Response models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -54,9 +54,7 @@ def _skill_to_info(skill: Any) -> dict[str, Any]:
|
|||
if hasattr(skill, "config") and hasattr(skill.config, "capabilities"):
|
||||
caps = skill.config.capabilities
|
||||
if isinstance(caps, list):
|
||||
capabilities = [
|
||||
c.tag if hasattr(c, "tag") else str(c) for c in caps
|
||||
]
|
||||
capabilities = [c.tag if hasattr(c, "tag") else str(c) for c in caps]
|
||||
elif isinstance(caps, dict):
|
||||
capabilities = list(caps.keys())
|
||||
|
||||
|
|
@ -160,7 +158,7 @@ async def check_skill_health(skill_name: str, req: Request):
|
|||
"""Check the health of a specific skill."""
|
||||
skill_registry = req.app.state.skill_registry
|
||||
try:
|
||||
skill = skill_registry.get(skill_name)
|
||||
skill_registry.get(skill_name)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
|
||||
|
||||
|
|
@ -213,17 +211,43 @@ async def list_capabilities(req: Request):
|
|||
|
||||
@router.post("/skill-management/skills/{skill_name}/reload")
|
||||
async def reload_skill(skill_name: str, req: Request):
|
||||
"""Reload a skill configuration."""
|
||||
"""Reload a skill configuration from its YAML file."""
|
||||
skill_registry = req.app.state.skill_registry
|
||||
# Verify the skill is currently registered (404 if not).
|
||||
try:
|
||||
skill = skill_registry.get(skill_name)
|
||||
skill_registry.get(skill_name)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
|
||||
|
||||
# In a full implementation, this would reload the skill from its config source
|
||||
# For now, just return success
|
||||
# Resolve the skills directory (mirrors routes.skills._get_skills_dir).
|
||||
import os
|
||||
|
||||
skills_dir: str
|
||||
server_config = getattr(req.app.state, "server_config", None)
|
||||
if server_config and getattr(server_config, "skill_paths", None):
|
||||
from pathlib import Path as _P
|
||||
|
||||
first_path = _P(server_config.skill_paths[0])
|
||||
if first_path.is_dir():
|
||||
skills_dir = str(first_path)
|
||||
else:
|
||||
skills_dir = os.path.join(os.getcwd(), "configs", "skills")
|
||||
else:
|
||||
skills_dir = os.path.join(os.getcwd(), "configs", "skills")
|
||||
|
||||
from agentkit.server.admin.skill_service import get_skill_service
|
||||
|
||||
svc = get_skill_service()
|
||||
try:
|
||||
result = await svc.reload_skill(skill_name, skill_registry, skills_dir)
|
||||
except ValueError as exc:
|
||||
msg = str(exc)
|
||||
if "not found" in msg:
|
||||
raise HTTPException(status_code=404, detail=msg) from exc
|
||||
raise HTTPException(status_code=400, detail=msg) from exc
|
||||
|
||||
return {
|
||||
"skill_name": skill_name,
|
||||
"status": "reloaded",
|
||||
"skill_name": result["name"],
|
||||
"status": result["status"],
|
||||
"message": f"技能 '{skill_name}' 已重新加载",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ def _get_skills_dir(req: Request) -> str:
|
|||
if server_config and server_config.skill_paths:
|
||||
# Use the first configured skill path as the install target
|
||||
from pathlib import Path as _P
|
||||
|
||||
first_path = _P(server_config.skill_paths[0])
|
||||
if first_path.is_dir():
|
||||
return str(first_path)
|
||||
|
|
@ -59,12 +60,16 @@ def _get_skills_dir(req: Request) -> str:
|
|||
def _validate_source_url(source: str) -> None:
|
||||
"""Validate that a source URL points to an allowed domain (SSRF mitigation)."""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(source)
|
||||
if parsed.scheme not in ("https", "http"):
|
||||
raise HTTPException(status_code=400, detail=f"Invalid source URL scheme: only http/https allowed")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid source URL scheme: only http/https allowed"
|
||||
)
|
||||
# Block private/internal IPs by checking hostname
|
||||
import ipaddress
|
||||
import socket
|
||||
|
||||
hostname = parsed.hostname
|
||||
if hostname:
|
||||
try:
|
||||
|
|
@ -82,12 +87,15 @@ def _validate_source_url(source: str) -> None:
|
|||
# Check domain allowlist for source URLs
|
||||
if hostname and hostname not in _ALLOWED_DOWNLOAD_DOMAINS:
|
||||
# Allow but log a warning for non-allowlisted domains
|
||||
logger.warning(f"Source URL domain '{hostname}' is not in the allowlist: {_ALLOWED_DOWNLOAD_DOMAINS}")
|
||||
logger.warning(
|
||||
f"Source URL domain '{hostname}' is not in the allowlist: {_ALLOWED_DOWNLOAD_DOMAINS}"
|
||||
)
|
||||
|
||||
|
||||
def _validate_yaml_content(content: str) -> dict:
|
||||
"""Validate YAML content before writing to disk. Returns parsed dict."""
|
||||
import yaml
|
||||
|
||||
try:
|
||||
data = yaml.safe_load(content)
|
||||
except yaml.YAMLError as e:
|
||||
|
|
@ -156,16 +164,37 @@ async def list_skills(
|
|||
Admin users (``role == "admin"``) bypass filtering and see all
|
||||
registered skills. Unauthenticated callers (API-key clients) see
|
||||
only global skills.
|
||||
|
||||
Disabled-skill filtering (U6): skills marked as disabled by an
|
||||
admin via ``POST /admin/skills/{name}/disable`` are excluded from
|
||||
the response for ALL callers (admins included). This keeps the
|
||||
public skill list in sync with the admin enable/disable state.
|
||||
"""
|
||||
skill_registry = req.app.state.skill_registry
|
||||
skills = skill_registry.list_skills()
|
||||
|
||||
# U6: filter out admin-disabled skills (applies to all callers).
|
||||
db_path = _resolve_db_path(req)
|
||||
if db_path.exists():
|
||||
try:
|
||||
from agentkit.server.admin.skill_service import get_skill_service
|
||||
|
||||
svc = get_skill_service()
|
||||
disabled_names = set(await svc.list_disabled_skills(db_path))
|
||||
except Exception: # noqa: BLE001 — never block listing on DB errors
|
||||
logger.exception("Failed to load disabled skills list — skipping filter")
|
||||
disabled_names = set()
|
||||
else:
|
||||
disabled_names = set()
|
||||
|
||||
if disabled_names:
|
||||
skills = [s for s in skills if s.name not in disabled_names]
|
||||
|
||||
# Admins bypass department filtering.
|
||||
if not dept_ctx.should_filter:
|
||||
return _serialize_skills(skills)
|
||||
|
||||
# Non-admin: filter by department bindings.
|
||||
db_path = _resolve_db_path(req)
|
||||
all_names = [s.name for s in skills]
|
||||
try:
|
||||
visible_names = await filter_skills_by_department(
|
||||
|
|
@ -218,15 +247,13 @@ async def mention_suggest(q: str = "", req: Request = None):
|
|||
|
||||
if query:
|
||||
skills = [
|
||||
s for s in skills
|
||||
s
|
||||
for s in skills
|
||||
if query in s.name.lower()
|
||||
or (s.config.description and query in s.config.description.lower())
|
||||
]
|
||||
|
||||
return [
|
||||
{"name": s.name, "description": s.config.description or ""}
|
||||
for s in skills[:8]
|
||||
]
|
||||
return [{"name": s.name, "description": s.config.description or ""} for s in skills[:8]]
|
||||
|
||||
|
||||
@router.post("/skills/install")
|
||||
|
|
@ -246,7 +273,9 @@ async def install_skill(request: InstallSkillRequest, req: Request):
|
|||
if source and source.startswith("http"):
|
||||
_validate_source_url(source)
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30, follow_redirects=True, max_redirects=3) as client:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=30, follow_redirects=True, max_redirects=3
|
||||
) as client:
|
||||
resp = await client.get(source)
|
||||
resp.raise_for_status()
|
||||
yaml_content = resp.text
|
||||
|
|
@ -260,7 +289,9 @@ async def install_skill(request: InstallSkillRequest, req: Request):
|
|||
# Verify the path is within the skills directory
|
||||
skills_dir_base = _get_skills_dir(req)
|
||||
if not os.path.realpath(local_path).startswith(os.path.realpath(skills_dir_base)):
|
||||
raise HTTPException(status_code=400, detail="Local file path must be within the skills directory")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Local file path must be within the skills directory"
|
||||
)
|
||||
try:
|
||||
with open(local_path, encoding="utf-8") as f:
|
||||
yaml_content = f.read()
|
||||
|
|
@ -290,12 +321,17 @@ async def install_skill(request: InstallSkillRequest, req: Request):
|
|||
# Fallback: try a simpler search
|
||||
search_query2 = f"{skill_name} skill"
|
||||
encoded_query2 = urllib.parse.quote(search_query2)
|
||||
github_api2 = f"https://api.github.com/search/code?q={encoded_query2}+extension:yaml&per_page=5"
|
||||
github_api2 = (
|
||||
f"https://api.github.com/search/code?q={encoded_query2}+extension:yaml&per_page=5"
|
||||
)
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=15) as client:
|
||||
gh_resp2 = await client.get(
|
||||
github_api2,
|
||||
headers={"Accept": "application/vnd.github.v3+json", "User-Agent": "agentkit"},
|
||||
headers={
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
"User-Agent": "agentkit",
|
||||
},
|
||||
)
|
||||
items = gh_resp2.json().get("items", [])
|
||||
except Exception:
|
||||
|
|
@ -310,13 +346,19 @@ async def install_skill(request: InstallSkillRequest, req: Request):
|
|||
if raw_url:
|
||||
# Validate the URL is from github.com before transforming
|
||||
if not raw_url.startswith("https://github.com/"):
|
||||
raise HTTPException(status_code=400, detail="Search result URL is not from github.com")
|
||||
raw_url = raw_url.replace("github.com", "raw.githubusercontent.com").replace("/blob/", "/")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Search result URL is not from github.com"
|
||||
)
|
||||
raw_url = raw_url.replace("github.com", "raw.githubusercontent.com").replace(
|
||||
"/blob/", "/"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Could not construct download URL")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30, follow_redirects=True, max_redirects=3) as client:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=30, follow_redirects=True, max_redirects=3
|
||||
) as client:
|
||||
resp = await client.get(raw_url)
|
||||
resp.raise_for_status()
|
||||
yaml_content = resp.text
|
||||
|
|
@ -342,6 +384,7 @@ async def install_skill(request: InstallSkillRequest, req: Request):
|
|||
registration_ok = False
|
||||
try:
|
||||
from agentkit.skills.loader import SkillLoader
|
||||
|
||||
loader = SkillLoader(
|
||||
skill_registry=skill_registry,
|
||||
tool_registry=tool_registry,
|
||||
|
|
@ -357,7 +400,7 @@ async def install_skill(request: InstallSkillRequest, req: Request):
|
|||
os.remove(file_path)
|
||||
except Exception:
|
||||
pass
|
||||
raise HTTPException(status_code=500, detail=f"Skill downloaded but registration failed")
|
||||
raise HTTPException(status_code=500, detail="Skill downloaded but registration failed")
|
||||
|
||||
return {
|
||||
"status": "installed",
|
||||
|
|
@ -387,7 +430,9 @@ async def uninstall_skill(name: str, req: Request):
|
|||
yaml_path = os.path.join(skills_dir, f"{validated_name}.yaml")
|
||||
|
||||
# Verify resolved path stays within skills_dir
|
||||
if os.path.exists(yaml_path) and os.path.realpath(yaml_path).startswith(os.path.realpath(skills_dir)):
|
||||
if os.path.exists(yaml_path) and os.path.realpath(yaml_path).startswith(
|
||||
os.path.realpath(skills_dir)
|
||||
):
|
||||
os.remove(yaml_path)
|
||||
|
||||
return {"status": "uninstalled", "name": validated_name}
|
||||
|
|
@ -419,8 +464,7 @@ async def create_pipeline(request: CreatePipelineRequest, req: Request):
|
|||
return {
|
||||
"name": pipeline.name,
|
||||
"steps": [
|
||||
{"skill_name": s["skill_name"], "step_index": i}
|
||||
for i, s in enumerate(request.steps)
|
||||
{"skill_name": s["skill_name"], "step_index": i} for i, s in enumerate(request.steps)
|
||||
],
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,522 @@
|
|||
"""Integration tests for the skill/KB admin routes (U6).
|
||||
|
||||
Uses FastAPI TestClient with a test app that mounts the
|
||||
``admin_router`` from ``routes.admin`` plus the public ``skills``
|
||||
router from ``routes.skills`` (to verify disabled-skill filtering).
|
||||
The ``_require_admin`` dependency is overridden via
|
||||
``app.dependency_overrides`` so the tests don't need real JWTs.
|
||||
|
||||
The :class:`SkillRegistry` is a real instance with a test skill
|
||||
loaded from a temp YAML file, so import/reload/update endpoints
|
||||
exercise the real SkillLoader pipeline.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from agentkit.server.admin.kb_service import set_kb_service
|
||||
from agentkit.server.admin.skill_service import set_skill_service
|
||||
from agentkit.server.auth.models import init_auth_db
|
||||
from agentkit.server.routes import admin as admin_routes_module
|
||||
from agentkit.server.routes import skills as skills_routes_module
|
||||
from agentkit.skills.registry import SkillRegistry
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test data
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_VALID_SKILL_YAML = """\
|
||||
name: admin_test_skill
|
||||
agent_type: simple_generation
|
||||
version: "1.0.0"
|
||||
description: "A test skill for admin route testing"
|
||||
task_mode: llm_generate
|
||||
execution_mode: direct
|
||||
max_steps: 1
|
||||
prompt:
|
||||
identity: "Test"
|
||||
instructions: "Handle test"
|
||||
tools: []
|
||||
"""
|
||||
|
||||
_VALID_SKILL_YAML_2 = """\
|
||||
name: another_test_skill
|
||||
agent_type: simple_generation
|
||||
version: "1.0.0"
|
||||
description: "Another test skill"
|
||||
task_mode: llm_generate
|
||||
execution_mode: direct
|
||||
max_steps: 1
|
||||
prompt:
|
||||
identity: "Test2"
|
||||
instructions: "Handle test 2"
|
||||
tools: []
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def tmp_auth_db(tmp_path: Path) -> Path:
|
||||
db_path = tmp_path / "admin_skill_kb.db"
|
||||
await init_auth_db(db_path)
|
||||
return db_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def skills_dir(tmp_path: Path) -> str:
|
||||
"""A temp skills directory for YAML files."""
|
||||
d = tmp_path / "skills"
|
||||
d.mkdir()
|
||||
return str(d)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def skill_registry() -> SkillRegistry:
|
||||
return SkillRegistry()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_singletons():
|
||||
"""Reset SkillService and KbService singletons before/after each test."""
|
||||
set_skill_service(None)
|
||||
set_kb_service(None)
|
||||
yield
|
||||
set_skill_service(None)
|
||||
set_kb_service(None)
|
||||
|
||||
|
||||
def _make_admin_user() -> dict[str, Any]:
|
||||
return {"user_id": "admin-1", "username": "admin", "role": "admin"}
|
||||
|
||||
|
||||
def _raise_forbidden() -> dict[str, Any]:
|
||||
"""Dependency override that simulates a non-admin (403) response."""
|
||||
raise HTTPException(status_code=403, detail="Admin permission required")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_app(
|
||||
tmp_auth_db: Path,
|
||||
skills_dir: str,
|
||||
skill_registry: SkillRegistry,
|
||||
) -> FastAPI:
|
||||
"""A minimal FastAPI app with admin + skills routers mounted.
|
||||
|
||||
The ``_require_admin`` dependency is overridden to return a fake
|
||||
admin user. The :class:`SkillRegistry` is set on ``app.state`` so
|
||||
the admin skill endpoints can access it.
|
||||
"""
|
||||
app = FastAPI()
|
||||
app.state.auth_db_path = str(tmp_auth_db)
|
||||
app.state.skill_registry = skill_registry
|
||||
|
||||
# Build a minimal server_config with skill_paths pointing at the
|
||||
# temp skills_dir, so the admin endpoints write YAML there.
|
||||
class _FakeServerConfig:
|
||||
skill_paths = [skills_dir]
|
||||
|
||||
app.state.server_config = _FakeServerConfig()
|
||||
|
||||
app.include_router(admin_routes_module.admin_router, prefix="/api/v1")
|
||||
app.include_router(skills_routes_module.router, prefix="/api/v1")
|
||||
|
||||
# Default: allow admin access.
|
||||
app.dependency_overrides[admin_routes_module._require_admin] = lambda: _make_admin_user()
|
||||
# Override the department context to admin (bypass filtering) so
|
||||
# GET /skills returns all skills (the disabled filter still applies).
|
||||
from agentkit.server.admin.context import DepartmentContext
|
||||
|
||||
app.dependency_overrides[skills_routes_module.get_department_context] = lambda: (
|
||||
DepartmentContext(user_id="admin-1", department_ids=[], is_admin=True)
|
||||
)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_client(admin_app: FastAPI) -> TestClient:
|
||||
return TestClient(admin_app)
|
||||
|
||||
|
||||
def _write_skill_yaml(skills_dir: str, name: str, content: str) -> str:
|
||||
"""Write a skill YAML file to the skills dir and return the path."""
|
||||
path = os.path.join(skills_dir, f"{name}.yaml")
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
return path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Skill enable/disable
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSkillEnableDisable:
|
||||
def test_disable_skill_returns_200(self, admin_client: TestClient):
|
||||
resp = admin_client.post("/api/v1/admin/skills/admin_test_skill/disable")
|
||||
assert resp.status_code == 200, resp.text
|
||||
body = resp.json()
|
||||
assert body["skill_name"] == "admin_test_skill"
|
||||
assert body["is_disabled"] is True
|
||||
|
||||
def test_disable_then_get_skills_excludes_it(
|
||||
self,
|
||||
admin_client: TestClient,
|
||||
skills_dir: str,
|
||||
skill_registry: SkillRegistry,
|
||||
):
|
||||
# Write a skill YAML and register it.
|
||||
_write_skill_yaml(skills_dir, "admin_test_skill", _VALID_SKILL_YAML)
|
||||
from agentkit.skills.loader import SkillLoader
|
||||
|
||||
SkillLoader(skill_registry=skill_registry).load_from_file(
|
||||
os.path.join(skills_dir, "admin_test_skill.yaml")
|
||||
)
|
||||
|
||||
# Verify the skill is initially listed.
|
||||
resp = admin_client.get("/api/v1/skills")
|
||||
assert resp.status_code == 200
|
||||
names = [s["name"] for s in resp.json()]
|
||||
assert "admin_test_skill" in names
|
||||
|
||||
# Disable the skill.
|
||||
resp = admin_client.post("/api/v1/admin/skills/admin_test_skill/disable")
|
||||
assert resp.status_code == 200
|
||||
|
||||
# GET /skills should now exclude it.
|
||||
resp = admin_client.get("/api/v1/skills")
|
||||
assert resp.status_code == 200
|
||||
names = [s["name"] for s in resp.json()]
|
||||
assert "admin_test_skill" not in names
|
||||
|
||||
def test_enable_skill_returns_200(
|
||||
self,
|
||||
admin_client: TestClient,
|
||||
skills_dir: str,
|
||||
skill_registry: SkillRegistry,
|
||||
):
|
||||
# Write and register the skill, then disable it.
|
||||
_write_skill_yaml(skills_dir, "admin_test_skill", _VALID_SKILL_YAML)
|
||||
from agentkit.skills.loader import SkillLoader
|
||||
|
||||
SkillLoader(skill_registry=skill_registry).load_from_file(
|
||||
os.path.join(skills_dir, "admin_test_skill.yaml")
|
||||
)
|
||||
admin_client.post("/api/v1/admin/skills/admin_test_skill/disable")
|
||||
|
||||
# Enable it.
|
||||
resp = admin_client.post("/api/v1/admin/skills/admin_test_skill/enable")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["enabled"] is True
|
||||
|
||||
# GET /skills should now include it.
|
||||
resp = admin_client.get("/api/v1/skills")
|
||||
names = [s["name"] for s in resp.json()]
|
||||
assert "admin_test_skill" in names
|
||||
|
||||
def test_enable_skill_not_disabled_returns_200(self, admin_client: TestClient):
|
||||
"""Enabling a skill that wasn't disabled returns enabled=False."""
|
||||
resp = admin_client.post("/api/v1/admin/skills/never_disabled/enable")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["enabled"] is False
|
||||
|
||||
def test_disable_skill_invalid_name_returns_400(self, admin_client: TestClient):
|
||||
resp = admin_client.post("/api/v1/admin/skills/Has Spaces/disable")
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_non_admin_cannot_disable_skill(self, admin_app: FastAPI):
|
||||
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
|
||||
client = TestClient(admin_app)
|
||||
resp = client.post("/api/v1/admin/skills/some_skill/disable")
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Skill import
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSkillImport:
|
||||
def test_import_valid_yaml_returns_200(
|
||||
self,
|
||||
admin_client: TestClient,
|
||||
skills_dir: str,
|
||||
skill_registry: SkillRegistry,
|
||||
):
|
||||
resp = admin_client.post(
|
||||
"/api/v1/admin/skills/import",
|
||||
json={"yaml_content": _VALID_SKILL_YAML},
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
body = resp.json()
|
||||
assert body["name"] == "admin_test_skill"
|
||||
assert os.path.isfile(body["path"])
|
||||
# Skill should be registered.
|
||||
assert skill_registry.has_skill("admin_test_skill")
|
||||
|
||||
def test_import_invalid_yaml_returns_400(self, admin_client: TestClient):
|
||||
resp = admin_client.post(
|
||||
"/api/v1/admin/skills/import",
|
||||
json={"yaml_content": "not: valid: yaml: ["},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_import_missing_name_returns_400(self, admin_client: TestClient):
|
||||
resp = admin_client.post(
|
||||
"/api/v1/admin/skills/import",
|
||||
json={"yaml_content": "agent_type: test\n"},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_import_invalid_skill_name_returns_400(self, admin_client: TestClient):
|
||||
bad_yaml = 'name: "Has Spaces"\nagent_type: test\n'
|
||||
resp = admin_client.post(
|
||||
"/api/v1/admin/skills/import",
|
||||
json={"yaml_content": bad_yaml},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_non_admin_cannot_import(self, admin_app: FastAPI):
|
||||
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
|
||||
client = TestClient(admin_app)
|
||||
resp = client.post(
|
||||
"/api/v1/admin/skills/import",
|
||||
json={"yaml_content": _VALID_SKILL_YAML},
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Skill reload
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSkillReload:
|
||||
def test_reload_existing_skill_returns_200(
|
||||
self,
|
||||
admin_client: TestClient,
|
||||
skills_dir: str,
|
||||
skill_registry: SkillRegistry,
|
||||
):
|
||||
# Write and register the skill first.
|
||||
_write_skill_yaml(skills_dir, "admin_test_skill", _VALID_SKILL_YAML)
|
||||
from agentkit.skills.loader import SkillLoader
|
||||
|
||||
SkillLoader(skill_registry=skill_registry).load_from_file(
|
||||
os.path.join(skills_dir, "admin_test_skill.yaml")
|
||||
)
|
||||
|
||||
resp = admin_client.post("/api/v1/admin/skills/admin_test_skill/reload")
|
||||
assert resp.status_code == 200, resp.text
|
||||
body = resp.json()
|
||||
assert body["name"] == "admin_test_skill"
|
||||
assert body["status"] == "reloaded"
|
||||
|
||||
def test_reload_nonexistent_skill_returns_404(self, admin_client: TestClient):
|
||||
resp = admin_client.post("/api/v1/admin/skills/nonexistent/reload")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Skill update (PATCH)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSkillUpdate:
|
||||
def test_update_skill_returns_200(
|
||||
self,
|
||||
admin_client: TestClient,
|
||||
skills_dir: str,
|
||||
skill_registry: SkillRegistry,
|
||||
):
|
||||
# Write and register the skill first.
|
||||
_write_skill_yaml(skills_dir, "admin_test_skill", _VALID_SKILL_YAML)
|
||||
from agentkit.skills.loader import SkillLoader
|
||||
|
||||
SkillLoader(skill_registry=skill_registry).load_from_file(
|
||||
os.path.join(skills_dir, "admin_test_skill.yaml")
|
||||
)
|
||||
|
||||
resp = admin_client.patch(
|
||||
"/api/v1/admin/skills/admin_test_skill",
|
||||
json={"config": {"description": "Updated via admin API"}},
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
body = resp.json()
|
||||
assert body["status"] == "updated"
|
||||
|
||||
# Verify the YAML file was updated.
|
||||
import yaml
|
||||
|
||||
with open(body["path"], encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
assert data["description"] == "Updated via admin API"
|
||||
|
||||
def test_update_nonexistent_skill_returns_404(self, admin_client: TestClient):
|
||||
resp = admin_client.patch(
|
||||
"/api/v1/admin/skills/nonexistent",
|
||||
json={"config": {"description": "x"}},
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# KB document endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestKbDocumentRoutes:
|
||||
def test_list_documents_returns_200(self, admin_client: TestClient):
|
||||
resp = admin_client.get("/api/v1/admin/kb/documents")
|
||||
assert resp.status_code == 200, resp.text
|
||||
body = resp.json()
|
||||
assert "documents" in body
|
||||
assert isinstance(body["documents"], list)
|
||||
|
||||
def test_upload_document_returns_201_with_department_id(self, admin_client: TestClient):
|
||||
dept_id = str(uuid.uuid4())
|
||||
resp = admin_client.post(
|
||||
"/api/v1/admin/kb/documents",
|
||||
json={
|
||||
"filename": "test.txt",
|
||||
"content": "hello world",
|
||||
"department_id": dept_id,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 201, resp.text
|
||||
body = resp.json()
|
||||
assert body["filename"] == "test.txt"
|
||||
assert body["department_id"] == dept_id
|
||||
assert "document_id" in body
|
||||
|
||||
def test_upload_document_without_department_id(self, admin_client: TestClient):
|
||||
resp = admin_client.post(
|
||||
"/api/v1/admin/kb/documents",
|
||||
json={"filename": "global.txt", "content": "global content"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
body = resp.json()
|
||||
assert body["department_id"] is None
|
||||
|
||||
def test_upload_then_delete_document(self, admin_client: TestClient):
|
||||
# Upload
|
||||
resp = admin_client.post(
|
||||
"/api/v1/admin/kb/documents",
|
||||
json={"filename": "delete_me.txt", "content": "x"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
doc_id = resp.json()["document_id"]
|
||||
|
||||
# Delete
|
||||
resp = admin_client.delete(f"/api/v1/admin/kb/documents/{doc_id}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"deleted": True}
|
||||
|
||||
# Second delete should 404.
|
||||
resp = admin_client.delete(f"/api/v1/admin/kb/documents/{doc_id}")
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_delete_nonexistent_document_returns_404(self, admin_client: TestClient):
|
||||
resp = admin_client.delete("/api/v1/admin/kb/documents/nonexistent-id")
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_list_documents_filters_by_department_id(self, admin_client: TestClient):
|
||||
dept_a = str(uuid.uuid4())
|
||||
dept_b = str(uuid.uuid4())
|
||||
|
||||
# Upload 3 docs: one for dept_a, one for dept_b, one global.
|
||||
admin_client.post(
|
||||
"/api/v1/admin/kb/documents",
|
||||
json={"filename": "a.txt", "content": "a", "department_id": dept_a},
|
||||
)
|
||||
admin_client.post(
|
||||
"/api/v1/admin/kb/documents",
|
||||
json={"filename": "b.txt", "content": "b", "department_id": dept_b},
|
||||
)
|
||||
admin_client.post(
|
||||
"/api/v1/admin/kb/documents",
|
||||
json={"filename": "global.txt", "content": "g"},
|
||||
)
|
||||
|
||||
# Filter by dept_a: should see dept_a's doc + global doc.
|
||||
resp = admin_client.get("/api/v1/admin/kb/documents", params={"department_id": dept_a})
|
||||
assert resp.status_code == 200
|
||||
docs = resp.json()["documents"]
|
||||
dept_ids = {d["department_id"] for d in docs}
|
||||
assert dept_a in dept_ids
|
||||
assert None in dept_ids # global doc
|
||||
assert dept_b not in dept_ids
|
||||
|
||||
def test_non_admin_cannot_list_documents(self, admin_app: FastAPI):
|
||||
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
|
||||
client = TestClient(admin_app)
|
||||
resp = client.get("/api/v1/admin/kb/documents")
|
||||
assert resp.status_code == 403
|
||||
|
||||
def test_non_admin_cannot_upload_document(self, admin_app: FastAPI):
|
||||
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
|
||||
client = TestClient(admin_app)
|
||||
resp = client.post(
|
||||
"/api/v1/admin/kb/documents",
|
||||
json={"filename": "x.txt", "content": "x"},
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# KB source sync/rebuild
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestKbSourceRoutes:
|
||||
def test_sync_nonexistent_source_returns_404(self, admin_client: TestClient):
|
||||
resp = admin_client.post("/api/v1/admin/kb/sources/nonexistent/sync")
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_rebuild_nonexistent_source_returns_404(self, admin_client: TestClient):
|
||||
resp = admin_client.post("/api/v1/admin/kb/sources/nonexistent/rebuild")
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_sync_existing_source_returns_200(self, admin_client: TestClient):
|
||||
# Add a source via the KbService (bypassing the route layer).
|
||||
from agentkit.server.admin.kb_service import get_kb_service
|
||||
|
||||
svc = get_kb_service()
|
||||
store = svc._resolve_store()
|
||||
source = store.add_source("Test Source", "local", {})
|
||||
|
||||
resp = admin_client.post(f"/api/v1/admin/kb/sources/{source.id}/sync")
|
||||
assert resp.status_code == 200, resp.text
|
||||
body = resp.json()
|
||||
assert body["status"] == "syncing"
|
||||
|
||||
def test_rebuild_existing_source_returns_200(self, admin_client: TestClient):
|
||||
from agentkit.server.admin.kb_service import get_kb_service
|
||||
|
||||
svc = get_kb_service()
|
||||
store = svc._resolve_store()
|
||||
source = store.add_source("Test Source", "local", {})
|
||||
|
||||
resp = admin_client.post(f"/api/v1/admin/kb/sources/{source.id}/rebuild")
|
||||
assert resp.status_code == 200, resp.text
|
||||
body = resp.json()
|
||||
assert body["status"] == "rebuilding"
|
||||
|
||||
def test_non_admin_cannot_sync_source(self, admin_app: FastAPI):
|
||||
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
|
||||
client = TestClient(admin_app)
|
||||
resp = client.post("/api/v1/admin/kb/sources/any/sync")
|
||||
assert resp.status_code == 403
|
||||
|
|
@ -4,7 +4,7 @@ Covers:
|
|||
- ``init_auth_db`` creates the new V3 tables (departments, user_departments,
|
||||
department_skill_bindings, department_kb_bindings, department_quotas)
|
||||
- ``init_auth_db`` is idempotent (calling twice does not error)
|
||||
- ``_SCHEMA_VERSION`` is recorded as 3 in ``auth_meta``
|
||||
- ``_SCHEMA_VERSION`` is recorded as 4 in ``auth_meta`` (V4 adds skill_states)
|
||||
- ``departments`` insert + query round-trip
|
||||
- ``user_departments`` many-to-many relationship (one user → many departments,
|
||||
one department → many users)
|
||||
|
|
@ -109,9 +109,7 @@ async def _list_index_names(db: aiosqlite.Connection, table: str) -> set[str]:
|
|||
|
||||
async def _list_table_names(db: aiosqlite.Connection) -> set[str]:
|
||||
"""Return the set of table names in the SQLite file."""
|
||||
cursor = await db.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table'"
|
||||
)
|
||||
cursor = await db.execute("SELECT name FROM sqlite_master WHERE type='table'")
|
||||
rows = await cursor.fetchall()
|
||||
return {row[0] for row in rows}
|
||||
|
||||
|
|
@ -122,9 +120,9 @@ async def _list_table_names(db: aiosqlite.Connection) -> set[str]:
|
|||
|
||||
|
||||
class TestSchemaVersion:
|
||||
def test_schema_version_is_v3(self):
|
||||
"""V3 adds the department-scoped admin tables."""
|
||||
assert _SCHEMA_VERSION == 3
|
||||
def test_schema_version_is_v4(self):
|
||||
"""V4 adds the skill_states table (U6 — Admin Console)."""
|
||||
assert _SCHEMA_VERSION == 4
|
||||
|
||||
def test_sqlalchemy_model_table_names(self):
|
||||
assert DepartmentModel.__tablename__ == "departments"
|
||||
|
|
@ -165,13 +163,13 @@ class TestInitAuthDbTables:
|
|||
tables = await _list_table_names(db)
|
||||
assert "department_quotas" in tables
|
||||
|
||||
async def test_records_schema_version_3_in_auth_meta(self, fresh_db: Path):
|
||||
async def test_records_schema_version_4_in_auth_meta(self, fresh_db: Path):
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("SELECT value FROM auth_meta WHERE key='schema_version'")
|
||||
row = await cursor.fetchone()
|
||||
assert row is not None
|
||||
assert row["value"] == "3"
|
||||
assert row["value"] == "4"
|
||||
assert row["value"] == str(_SCHEMA_VERSION)
|
||||
|
||||
async def test_init_auth_db_is_idempotent(self, tmp_path: Path):
|
||||
|
|
@ -279,9 +277,7 @@ class TestDepartmentsCrud:
|
|||
)
|
||||
await db.commit()
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute(
|
||||
"SELECT description FROM departments WHERE id=?", (dept_id,)
|
||||
)
|
||||
cursor = await db.execute("SELECT description FROM departments WHERE id=?", (dept_id,))
|
||||
row = await cursor.fetchone()
|
||||
assert row is not None
|
||||
assert row["description"] is None
|
||||
|
|
@ -310,8 +306,7 @@ class TestUserDepartmentsManyToMany:
|
|||
await db.commit()
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute(
|
||||
"SELECT department_id FROM user_departments WHERE user_id=? "
|
||||
"ORDER BY department_id",
|
||||
"SELECT department_id FROM user_departments WHERE user_id=? ORDER BY department_id",
|
||||
(user_id,),
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
|
|
@ -336,8 +331,7 @@ class TestUserDepartmentsManyToMany:
|
|||
await db.commit()
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute(
|
||||
"SELECT user_id FROM user_departments WHERE department_id=? "
|
||||
"ORDER BY user_id",
|
||||
"SELECT user_id FROM user_departments WHERE department_id=? ORDER BY user_id",
|
||||
(dept_id,),
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
|
|
@ -394,9 +388,7 @@ class TestDepartmentSkillBindingsUnique:
|
|||
)
|
||||
await db.commit()
|
||||
|
||||
async def test_same_skill_name_in_different_departments_is_allowed(
|
||||
self, fresh_db: Path
|
||||
):
|
||||
async def test_same_skill_name_in_different_departments_is_allowed(self, fresh_db: Path):
|
||||
dept_a = str(uuid.uuid4())
|
||||
dept_b = str(uuid.uuid4())
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,386 @@
|
|||
"""Unit tests for SkillService (U6 — skill enable/disable + import/reload).
|
||||
|
||||
Covers:
|
||||
- disable_skill → is_skill_disabled returns True
|
||||
- enable_skill → is_skill_disabled returns False
|
||||
- list_disabled_skills → returns correct list
|
||||
- import_skill with valid YAML → file written, skill loaded
|
||||
- import_skill with invalid YAML → ValueError
|
||||
- import_skill with path traversal attempt → ValueError
|
||||
- reload_skill → unregister + reload
|
||||
- update_skill_config → YAML updated + reloaded
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.server.admin.skill_service import SkillService, _validate_skill_name
|
||||
from agentkit.server.auth.models import init_auth_db
|
||||
from agentkit.skills.registry import SkillRegistry
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def fresh_db(tmp_path: Path) -> Path:
|
||||
"""A brand-new auth DB on a fresh path (no data)."""
|
||||
db_path = tmp_path / "auth.db"
|
||||
await init_auth_db(db_path)
|
||||
return db_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service() -> SkillService:
|
||||
return SkillService()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def skills_dir(tmp_path: Path) -> str:
|
||||
"""A temp skills directory for YAML files."""
|
||||
d = tmp_path / "skills"
|
||||
d.mkdir()
|
||||
return str(d)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def skill_registry() -> SkillRegistry:
|
||||
return SkillRegistry()
|
||||
|
||||
|
||||
_VALID_SKILL_YAML = """\
|
||||
name: test_skill
|
||||
agent_type: simple_generation
|
||||
version: "1.0.0"
|
||||
description: "A test skill for unit testing"
|
||||
task_mode: llm_generate
|
||||
execution_mode: direct
|
||||
max_steps: 1
|
||||
prompt:
|
||||
identity: "Test"
|
||||
instructions: "Handle test"
|
||||
tools: []
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Enable / disable
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDisableEnable:
|
||||
async def test_disable_skill_marks_disabled(self, service: SkillService, fresh_db: Path):
|
||||
result = await service.disable_skill(fresh_db, "test_skill", disabled_by="admin-1")
|
||||
assert result["skill_name"] == "test_skill"
|
||||
assert result["is_disabled"] is True
|
||||
assert result["disabled_by"] == "admin-1"
|
||||
assert "disabled_at" in result
|
||||
|
||||
assert await service.is_skill_disabled(fresh_db, "test_skill") is True
|
||||
|
||||
async def test_disable_skill_normalizes_name(self, service: SkillService, fresh_db: Path):
|
||||
"""Skill names are normalized to lowercase before storage."""
|
||||
await service.disable_skill(fresh_db, "TestSkill")
|
||||
assert await service.is_skill_disabled(fresh_db, "testskill") is True
|
||||
assert await service.is_skill_disabled(fresh_db, "TestSkill") is True
|
||||
|
||||
async def test_disable_skill_is_idempotent(self, service: SkillService, fresh_db: Path):
|
||||
"""Disabling an already-disabled skill updates the row, not duplicates."""
|
||||
await service.disable_skill(fresh_db, "skill_a", disabled_by="admin-1")
|
||||
await service.disable_skill(fresh_db, "skill_a", disabled_by="admin-2")
|
||||
disabled = await service.list_disabled_skills(fresh_db)
|
||||
assert disabled == ["skill_a"]
|
||||
|
||||
async def test_enable_skill_removes_disabled_mark(self, service: SkillService, fresh_db: Path):
|
||||
await service.disable_skill(fresh_db, "skill_b")
|
||||
assert await service.is_skill_disabled(fresh_db, "skill_b") is True
|
||||
|
||||
enabled = await service.enable_skill(fresh_db, "skill_b")
|
||||
assert enabled is True
|
||||
assert await service.is_skill_disabled(fresh_db, "skill_b") is False
|
||||
|
||||
async def test_enable_skill_returns_false_if_not_disabled(
|
||||
self, service: SkillService, fresh_db: Path
|
||||
):
|
||||
"""Enabling a skill that wasn't disabled returns False (no-op)."""
|
||||
enabled = await service.enable_skill(fresh_db, "never_disabled")
|
||||
assert enabled is False
|
||||
|
||||
async def test_is_skill_disabled_returns_false_for_unknown(
|
||||
self, service: SkillService, fresh_db: Path
|
||||
):
|
||||
assert await service.is_skill_disabled(fresh_db, "unknown_skill") is False
|
||||
|
||||
async def test_list_disabled_skills_returns_sorted_list(
|
||||
self, service: SkillService, fresh_db: Path
|
||||
):
|
||||
await service.disable_skill(fresh_db, "charlie")
|
||||
await service.disable_skill(fresh_db, "alpha")
|
||||
await service.disable_skill(fresh_db, "bravo")
|
||||
result = await service.list_disabled_skills(fresh_db)
|
||||
assert result == ["alpha", "bravo", "charlie"]
|
||||
|
||||
async def test_list_disabled_skills_empty_when_none_disabled(
|
||||
self, service: SkillService, fresh_db: Path
|
||||
):
|
||||
result = await service.list_disabled_skills(fresh_db)
|
||||
assert result == []
|
||||
|
||||
async def test_disable_skill_invalid_name_raises(self, service: SkillService, fresh_db: Path):
|
||||
with pytest.raises(ValueError, match="Invalid skill name"):
|
||||
await service.disable_skill(fresh_db, "Invalid Name With Spaces")
|
||||
|
||||
async def test_disable_skill_uppercase_name_normalizes(
|
||||
self, service: SkillService, fresh_db: Path
|
||||
):
|
||||
"""Uppercase names are normalized to lowercase (regex requires lowercase)."""
|
||||
# 'TEST_SKILL' normalizes to 'test_skill' which matches the regex.
|
||||
await service.disable_skill(fresh_db, "TEST_SKILL")
|
||||
assert await service.is_skill_disabled(fresh_db, "test_skill") is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# import_skill
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestImportSkill:
|
||||
async def test_import_skill_writes_file_and_loads(
|
||||
self,
|
||||
service: SkillService,
|
||||
skills_dir: str,
|
||||
skill_registry: SkillRegistry,
|
||||
):
|
||||
result = await service.import_skill(
|
||||
_VALID_SKILL_YAML,
|
||||
skills_dir,
|
||||
skill_registry=skill_registry,
|
||||
)
|
||||
assert result["name"] == "test_skill"
|
||||
assert Path(result["path"]).is_file()
|
||||
# Skill should be registered in the registry.
|
||||
assert skill_registry.has_skill("test_skill")
|
||||
|
||||
async def test_import_skill_without_registry_writes_file_only(
|
||||
self,
|
||||
service: SkillService,
|
||||
skills_dir: str,
|
||||
):
|
||||
result = await service.import_skill(_VALID_SKILL_YAML, skills_dir)
|
||||
assert result["name"] == "test_skill"
|
||||
assert Path(result["path"]).is_file()
|
||||
|
||||
async def test_import_skill_invalid_yaml_raises(
|
||||
self,
|
||||
service: SkillService,
|
||||
skills_dir: str,
|
||||
):
|
||||
with pytest.raises(ValueError, match="Invalid YAML"):
|
||||
await service.import_skill("not: valid: yaml: [", skills_dir)
|
||||
|
||||
async def test_import_skill_non_mapping_yaml_raises(
|
||||
self,
|
||||
service: SkillService,
|
||||
skills_dir: str,
|
||||
):
|
||||
with pytest.raises(ValueError, match="mapping"):
|
||||
await service.import_skill("- just\n- a\n- list\n", skills_dir)
|
||||
|
||||
async def test_import_skill_missing_name_field_raises(
|
||||
self,
|
||||
service: SkillService,
|
||||
skills_dir: str,
|
||||
):
|
||||
yaml_without_name = "agent_type: simple_generation\n"
|
||||
with pytest.raises(ValueError, match="name"):
|
||||
await service.import_skill(yaml_without_name, skills_dir)
|
||||
|
||||
async def test_import_skill_invalid_name_raises(
|
||||
self,
|
||||
service: SkillService,
|
||||
skills_dir: str,
|
||||
):
|
||||
"""YAML with a name that fails the regex raises ValueError."""
|
||||
bad_yaml = 'name: "Bad Name With Spaces"\nagent_type: test\n'
|
||||
with pytest.raises(ValueError, match="Invalid skill name"):
|
||||
await service.import_skill(bad_yaml, skills_dir)
|
||||
|
||||
async def test_import_skill_path_traversal_blocked(
|
||||
self,
|
||||
service: SkillService,
|
||||
skills_dir: str,
|
||||
):
|
||||
"""A YAML with a name containing path separators is rejected by the regex."""
|
||||
# The regex `^[a-z0-9][a-z0-9_-]{0,63}$` rejects '/' and '..',
|
||||
# so path traversal via the name field is impossible. We verify
|
||||
# this by attempting to import a YAML with a traversal name.
|
||||
traversal_yaml = 'name: "../etc/passwd"\nagent_type: test\n'
|
||||
with pytest.raises(ValueError, match="Invalid skill name"):
|
||||
await service.import_skill(traversal_yaml, skills_dir)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# reload_skill
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReloadSkill:
|
||||
async def test_reload_skill_unregisters_and_reloads(
|
||||
self,
|
||||
service: SkillService,
|
||||
skills_dir: str,
|
||||
skill_registry: SkillRegistry,
|
||||
):
|
||||
# First import the skill.
|
||||
await service.import_skill(
|
||||
_VALID_SKILL_YAML,
|
||||
skills_dir,
|
||||
skill_registry=skill_registry,
|
||||
)
|
||||
assert skill_registry.has_skill("test_skill")
|
||||
|
||||
# Now reload it.
|
||||
result = await service.reload_skill("test_skill", skill_registry, skills_dir)
|
||||
assert result["name"] == "test_skill"
|
||||
assert result["status"] == "reloaded"
|
||||
assert skill_registry.has_skill("test_skill")
|
||||
|
||||
async def test_reload_skill_missing_yaml_raises(
|
||||
self,
|
||||
service: SkillService,
|
||||
skills_dir: str,
|
||||
skill_registry: SkillRegistry,
|
||||
):
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
await service.reload_skill("nonexistent", skill_registry, skills_dir)
|
||||
|
||||
async def test_reload_skill_invalid_name_raises(
|
||||
self,
|
||||
service: SkillService,
|
||||
skills_dir: str,
|
||||
skill_registry: SkillRegistry,
|
||||
):
|
||||
with pytest.raises(ValueError, match="Invalid skill name"):
|
||||
await service.reload_skill("Bad Name", skill_registry, skills_dir)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# update_skill_config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUpdateSkillConfig:
|
||||
async def test_update_skill_config_updates_yaml_and_reloads(
|
||||
self,
|
||||
service: SkillService,
|
||||
skills_dir: str,
|
||||
skill_registry: SkillRegistry,
|
||||
):
|
||||
# First import the skill.
|
||||
await service.import_skill(
|
||||
_VALID_SKILL_YAML,
|
||||
skills_dir,
|
||||
skill_registry=skill_registry,
|
||||
)
|
||||
|
||||
# Patch the description.
|
||||
result = await service.update_skill_config(
|
||||
"test_skill",
|
||||
{"description": "Updated description"},
|
||||
skills_dir,
|
||||
skill_registry,
|
||||
)
|
||||
assert result["name"] == "test_skill"
|
||||
assert result["status"] == "updated"
|
||||
|
||||
# Verify the YAML file was updated.
|
||||
import yaml
|
||||
|
||||
with open(result["path"], encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
assert data["description"] == "Updated description"
|
||||
# Original fields should be preserved.
|
||||
assert data["name"] == "test_skill"
|
||||
assert data["agent_type"] == "simple_generation"
|
||||
|
||||
# Skill should still be registered.
|
||||
assert skill_registry.has_skill("test_skill")
|
||||
|
||||
async def test_update_skill_config_missing_yaml_raises(
|
||||
self,
|
||||
service: SkillService,
|
||||
skills_dir: str,
|
||||
skill_registry: SkillRegistry,
|
||||
):
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
await service.update_skill_config(
|
||||
"nonexistent",
|
||||
{"description": "x"},
|
||||
skills_dir,
|
||||
skill_registry,
|
||||
)
|
||||
|
||||
async def test_update_skill_config_preserves_name(
|
||||
self,
|
||||
service: SkillService,
|
||||
skills_dir: str,
|
||||
skill_registry: SkillRegistry,
|
||||
):
|
||||
"""Patching the 'name' field is ignored — name is preserved."""
|
||||
await service.import_skill(
|
||||
_VALID_SKILL_YAML,
|
||||
skills_dir,
|
||||
skill_registry=skill_registry,
|
||||
)
|
||||
|
||||
result = await service.update_skill_config(
|
||||
"test_skill",
|
||||
{"name": "should_be_ignored", "description": "new"},
|
||||
skills_dir,
|
||||
skill_registry,
|
||||
)
|
||||
import yaml
|
||||
|
||||
with open(result["path"], encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
# Name should be preserved (not renamed).
|
||||
assert data["name"] == "test_skill"
|
||||
assert data["description"] == "new"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_skill_name helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidateSkillName:
|
||||
def test_valid_name_returns_normalized(self):
|
||||
assert _validate_skill_name("test_skill") == "test_skill"
|
||||
assert _validate_skill_name("TestSkill") == "testskill"
|
||||
assert _validate_skill_name(" spaced ") == "spaced"
|
||||
assert _validate_skill_name("a-b_c-1") == "a-b_c-1"
|
||||
|
||||
def test_invalid_name_raises(self):
|
||||
with pytest.raises(ValueError):
|
||||
_validate_skill_name("Has Spaces")
|
||||
with pytest.raises(ValueError):
|
||||
_validate_skill_name("")
|
||||
with pytest.raises(ValueError):
|
||||
_validate_skill_name("-leading-dash")
|
||||
with pytest.raises(ValueError):
|
||||
_validate_skill_name("../traversal")
|
||||
with pytest.raises(ValueError):
|
||||
_validate_skill_name("has.dot")
|
||||
with pytest.raises(ValueError):
|
||||
_validate_skill_name("has/slash")
|
||||
|
||||
def test_non_string_raises(self):
|
||||
with pytest.raises(ValueError):
|
||||
_validate_skill_name(None) # type: ignore[arg-type]
|
||||
with pytest.raises(ValueError):
|
||||
_validate_skill_name(123) # type: ignore[arg-type]
|
||||
|
|
@ -118,9 +118,9 @@ async def _list_index_names(db: aiosqlite.Connection, table: str) -> set[str]:
|
|||
|
||||
|
||||
class TestSchemaVersion:
|
||||
def test_schema_version_is_v3(self):
|
||||
"""The current schema version is 3 (V3 adds department-scoped admin tables)."""
|
||||
assert _SCHEMA_VERSION == 3
|
||||
def test_schema_version_is_v4(self):
|
||||
"""The current schema version is 4 (V4 adds skill_states table)."""
|
||||
assert _SCHEMA_VERSION == 4
|
||||
|
||||
def test_sqlalchemy_model_table_name(self):
|
||||
assert AuthSessionModel.__tablename__ == "auth_sessions"
|
||||
|
|
|
|||
|
|
@ -2,13 +2,12 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
from agentkit.server.app import create_app
|
||||
from agentkit.server.config import ServerConfig
|
||||
from agentkit.skills.base import Skill, SkillConfig
|
||||
from agentkit.skills.registry import SkillRegistry
|
||||
from agentkit.tools.registry import ToolRegistry
|
||||
|
|
@ -202,7 +201,31 @@ class TestListCapabilities:
|
|||
|
||||
|
||||
class TestReloadSkill:
|
||||
def test_reload_skill(self, client, skill_registry):
|
||||
def test_reload_skill(self, client, skill_registry, tmp_path):
|
||||
# Write a valid YAML file for "reload_skill" to a temp skills dir.
|
||||
skills_dir = tmp_path / "skills"
|
||||
skills_dir.mkdir()
|
||||
yaml_content = (
|
||||
"name: reload_skill\n"
|
||||
'agent_type: "test_type"\n'
|
||||
'version: "1.0.0"\n'
|
||||
'description: "Skill for reload testing"\n'
|
||||
"task_mode: llm_generate\n"
|
||||
"execution_mode: direct\n"
|
||||
"max_steps: 1\n"
|
||||
"intent:\n"
|
||||
' keywords: ["reload"]\n'
|
||||
' description: "reload test"\n'
|
||||
"prompt:\n"
|
||||
' identity: "reload tester"\n'
|
||||
' instructions: "handle reload"\n'
|
||||
"tools: []\n"
|
||||
)
|
||||
(skills_dir / "reload_skill.yaml").write_text(yaml_content, encoding="utf-8")
|
||||
|
||||
# Point the app at the temp skills dir and register the skill so the
|
||||
# 404 guard in the route passes.
|
||||
client.app.state.server_config = ServerConfig(skill_paths=[str(skills_dir)])
|
||||
_register_skill(skill_registry, "reload_skill")
|
||||
|
||||
response = client.post("/api/v1/skill-management/skills/reload_skill/reload")
|
||||
|
|
|
|||
Loading…
Reference in New Issue