feat(admin): U6 — Skill & KB management endpoints + department binding
SkillService: enable/disable (persisted in skill_states table, schema v4), import from YAML (with path traversal + name validation), reload from file, update config. GET /skills now filters disabled skills. KbService: list/upload/delete documents with department_id binding. Added department_id field to KnowledgeSource + UploadedDocument. Department visibility: (bound to user depts) ∪ (global = None). 10 new admin endpoints: skill enable/disable/import/reload/update, KB documents CRUD, source sync/rebuild. All guarded by _require_admin. Implemented reload stub in skill_management.py (was no-op). 54 new tests (26 unit + 28 integration). Fixed 4 pre-existing lint errors. 357 admin tests pass, no regressions.
This commit is contained in:
parent
980919fc95
commit
fd7f6816b8
|
|
@ -0,0 +1,246 @@
|
||||||
|
"""KbService — admin-driven KB document/source management (U6).
|
||||||
|
|
||||||
|
This module wraps the existing in-memory :class:`KnowledgeSourceStore`
|
||||||
|
from :mod:`agentkit.server.routes.kb_management` so admin routes can
|
||||||
|
manage documents and sources through a service layer (mirroring
|
||||||
|
:class:`DepartmentService` and :class:`SkillService`).
|
||||||
|
|
||||||
|
Department filtering
|
||||||
|
--------------------
|
||||||
|
When listing documents, the service filters by department visibility:
|
||||||
|
- documents whose ``department_id`` is in the caller's department set, OR
|
||||||
|
- documents whose ``department_id`` is ``None`` (global documents).
|
||||||
|
|
||||||
|
Admin callers (``department_ids=None``) bypass filtering and see all
|
||||||
|
documents.
|
||||||
|
|
||||||
|
The service is a module-level singleton (see :func:`get_kb_service`)
|
||||||
|
so tests can inject a custom instance via :func:`set_kb_service`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.server.routes.kb_management import (
|
||||||
|
KnowledgeSource,
|
||||||
|
KnowledgeSourceStore,
|
||||||
|
UploadedDocument,
|
||||||
|
get_source_store,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _now_iso() -> str:
|
||||||
|
"""Return current UTC time as ISO 8601 string."""
|
||||||
|
return datetime.now(timezone.utc).isoformat()
|
||||||
|
|
||||||
|
|
||||||
|
class KbService:
|
||||||
|
"""Wraps :class:`KnowledgeSourceStore` with admin-friendly operations.
|
||||||
|
|
||||||
|
The service does NOT own the store — it delegates to the
|
||||||
|
module-level singleton from :mod:`kb_management`. This keeps a
|
||||||
|
single source of truth for KB state (the existing routes still
|
||||||
|
work) while giving admin routes a clean service boundary.
|
||||||
|
|
||||||
|
Department filtering is applied at the service layer so admin
|
||||||
|
routes can pass ``department_ids`` directly from the
|
||||||
|
:class:`DepartmentContext` dependency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, store: KnowledgeSourceStore | None = None) -> None:
|
||||||
|
self._store = store
|
||||||
|
|
||||||
|
def _resolve_store(self) -> KnowledgeSourceStore:
|
||||||
|
"""Return the injected store, or the module singleton."""
|
||||||
|
if self._store is not None:
|
||||||
|
return self._store
|
||||||
|
return get_source_store()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Document operations
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def list_documents(
|
||||||
|
self,
|
||||||
|
department_ids: list[str] | None = None,
|
||||||
|
source_id: str | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""List documents, optionally filtered by source and department.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
department_ids: Caller's department ids. ``None`` means admin
|
||||||
|
(no filtering — return all documents). An empty list means
|
||||||
|
a user with no departments (return only global documents).
|
||||||
|
source_id: Optional source id filter.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of document dicts (JSON-safe).
|
||||||
|
"""
|
||||||
|
store = self._resolve_store()
|
||||||
|
documents = store.list_documents(source_id=source_id)
|
||||||
|
|
||||||
|
# Admin bypass: department_ids=None means "see everything".
|
||||||
|
if department_ids is None:
|
||||||
|
return [self._doc_to_dict(d) for d in documents]
|
||||||
|
|
||||||
|
# Non-admin: visible = (docs in user's departments) ∪ (global docs).
|
||||||
|
dept_set = set(department_ids)
|
||||||
|
visible = [d for d in documents if d.department_id is None or d.department_id in dept_set]
|
||||||
|
return [self._doc_to_dict(d) for d in visible]
|
||||||
|
|
||||||
|
def upload_document(
|
||||||
|
self,
|
||||||
|
filename: str,
|
||||||
|
content: bytes,
|
||||||
|
source_id: str = "",
|
||||||
|
department_id: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Upload a document with optional department binding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Original filename.
|
||||||
|
content: File content bytes (used to estimate chunk count).
|
||||||
|
source_id: Optional source id. Defaults to ``"local"``.
|
||||||
|
department_id: Optional department id to bind the document to.
|
||||||
|
``None`` means the document is global (visible to all).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The uploaded document dict.
|
||||||
|
"""
|
||||||
|
store = self._resolve_store()
|
||||||
|
effective_source_id = source_id or "local"
|
||||||
|
|
||||||
|
# Ensure a local source exists for default uploads.
|
||||||
|
if effective_source_id == "local":
|
||||||
|
local_sources = [s for s in store.list_sources() if s.type == "local"]
|
||||||
|
if not local_sources:
|
||||||
|
store.add_source("本地文档", "local", {})
|
||||||
|
|
||||||
|
# Estimate chunks from content length (rough approximation,
|
||||||
|
# mirrors the existing /kb-management/documents/upload route).
|
||||||
|
chunks = max(1, len(content) // 500)
|
||||||
|
|
||||||
|
doc = UploadedDocument(
|
||||||
|
document_id=str(uuid.uuid4()),
|
||||||
|
filename=filename,
|
||||||
|
source_id=effective_source_id,
|
||||||
|
chunks=chunks,
|
||||||
|
status="indexed",
|
||||||
|
department_id=department_id,
|
||||||
|
)
|
||||||
|
store.add_document(doc)
|
||||||
|
return self._doc_to_dict(doc)
|
||||||
|
|
||||||
|
def delete_document(self, document_id: str) -> bool:
|
||||||
|
"""Delete a document by id. Returns ``True`` if deleted."""
|
||||||
|
store = self._resolve_store()
|
||||||
|
return store.delete_document(document_id)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Source operations
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def list_sources(self) -> list[dict[str, Any]]:
|
||||||
|
"""List all KB sources."""
|
||||||
|
store = self._resolve_store()
|
||||||
|
return [self._source_to_dict(s) for s in store.list_sources()]
|
||||||
|
|
||||||
|
def get_source(self, source_id: str) -> dict[str, Any] | None:
|
||||||
|
"""Return a single source by id, or ``None`` if not found."""
|
||||||
|
store = self._resolve_store()
|
||||||
|
source = store.get_source(source_id)
|
||||||
|
return self._source_to_dict(source) if source else None
|
||||||
|
|
||||||
|
def sync_source(self, source_id: str) -> dict[str, Any]:
|
||||||
|
"""Trigger a source sync (stub — marks status as ``syncing``).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the source does not exist.
|
||||||
|
"""
|
||||||
|
store = self._resolve_store()
|
||||||
|
source = store.get_source(source_id)
|
||||||
|
if source is None:
|
||||||
|
raise ValueError(f"KB source {source_id!r} not found")
|
||||||
|
# Stub: mark status as syncing. A real implementation would
|
||||||
|
# kick off an async sync job and return a job id.
|
||||||
|
source.status = "syncing"
|
||||||
|
source.last_synced = _now_iso()
|
||||||
|
return {
|
||||||
|
"source_id": source_id,
|
||||||
|
"status": "syncing",
|
||||||
|
"message": "Sync started",
|
||||||
|
}
|
||||||
|
|
||||||
|
def rebuild_index(self, source_id: str) -> dict[str, Any]:
|
||||||
|
"""Rebuild the index for a source (stub — marks status).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the source does not exist.
|
||||||
|
"""
|
||||||
|
store = self._resolve_store()
|
||||||
|
source = store.get_source(source_id)
|
||||||
|
if source is None:
|
||||||
|
raise ValueError(f"KB source {source_id!r} not found")
|
||||||
|
# Stub: mark status as rebuilding.
|
||||||
|
source.status = "rebuilding"
|
||||||
|
return {
|
||||||
|
"source_id": source_id,
|
||||||
|
"status": "rebuilding",
|
||||||
|
"message": "Index rebuild started",
|
||||||
|
}
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Serialization helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _doc_to_dict(self, doc: UploadedDocument) -> dict[str, Any]:
|
||||||
|
"""Convert an :class:`UploadedDocument` to a JSON-safe dict."""
|
||||||
|
return {
|
||||||
|
"document_id": doc.document_id,
|
||||||
|
"filename": doc.filename,
|
||||||
|
"source_id": doc.source_id,
|
||||||
|
"chunks": doc.chunks,
|
||||||
|
"status": doc.status,
|
||||||
|
"created_at": doc.created_at,
|
||||||
|
"department_id": doc.department_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _source_to_dict(self, source: KnowledgeSource) -> dict[str, Any]:
|
||||||
|
"""Convert a :class:`KnowledgeSource` to a JSON-safe dict."""
|
||||||
|
return {
|
||||||
|
"id": source.id,
|
||||||
|
"name": source.name,
|
||||||
|
"type": source.type,
|
||||||
|
"status": source.status,
|
||||||
|
"document_count": source.document_count,
|
||||||
|
"last_synced": source.last_synced,
|
||||||
|
"department_id": source.department_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Module-level singleton (overridable in tests via set_kb_service)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
_kb_service: KbService | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_kb_service() -> KbService:
|
||||||
|
"""Return the process-wide :class:`KbService` (lazy singleton)."""
|
||||||
|
global _kb_service
|
||||||
|
if _kb_service is None:
|
||||||
|
_kb_service = KbService()
|
||||||
|
return _kb_service
|
||||||
|
|
||||||
|
|
||||||
|
def set_kb_service(service: KbService | None) -> None:
|
||||||
|
"""Inject a custom :class:`KbService` (used by tests)."""
|
||||||
|
global _kb_service
|
||||||
|
_kb_service = service
|
||||||
|
|
@ -0,0 +1,428 @@
|
||||||
|
"""SkillService — admin-driven skill enable/disable + import/reload (U6).
|
||||||
|
|
||||||
|
This module is the single owner of the ``skill_states`` table (admin
|
||||||
|
disable markers) and the YAML import/reload pipeline. Web UI routes
|
||||||
|
(``/api/v1/admin/skills/*``) and the existing skill-management route
|
||||||
|
both call into :class:`SkillService` rather than touching the table
|
||||||
|
or filesystem directly, keeping validation rules (skill-name regex,
|
||||||
|
path-traversal guard, YAML validation) in one place.
|
||||||
|
|
||||||
|
The service is a module-level singleton (see :func:`get_skill_service`)
|
||||||
|
so tests can inject a custom instance via :func:`set_skill_service`.
|
||||||
|
|
||||||
|
Design notes
|
||||||
|
------------
|
||||||
|
- Each method opens its own short-lived :class:`aiosqlite.Connection`
|
||||||
|
(mirrors :class:`DepartmentService`).
|
||||||
|
- ``import_skill`` reuses the validation regex from
|
||||||
|
:mod:`agentkit.server.routes.skills` but inlines a copy to avoid a
|
||||||
|
circular import (routes.skills → admin.skill_service → routes.skills).
|
||||||
|
- ``reload_skill`` and ``update_skill_config`` accept a
|
||||||
|
:class:`SkillRegistry` instance so the caller can pass the live
|
||||||
|
registry from ``app.state``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from agentkit.server.auth.models import skill_state_row_to_dict
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from agentkit.skills.registry import SkillRegistry
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Constants & helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Strict skill name validation: lowercase alphanumeric, hyphens, underscores.
|
||||||
|
# Inlined from routes/skills.py to avoid a circular import.
|
||||||
|
_SKILL_NAME_RE = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")
|
||||||
|
|
||||||
|
|
||||||
|
def _now_iso() -> str:
|
||||||
|
"""Return current UTC time as ISO 8601 string."""
|
||||||
|
return datetime.now(timezone.utc).isoformat()
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_skill_name(name: str) -> str:
|
||||||
|
"""Validate and normalize a skill name. Raises ``ValueError`` on invalid input.
|
||||||
|
|
||||||
|
Mirrors the regex in :mod:`agentkit.server.routes.skills` but raises
|
||||||
|
``ValueError`` (not ``HTTPException``) so the service is HTTP-agnostic.
|
||||||
|
"""
|
||||||
|
if not isinstance(name, str):
|
||||||
|
raise ValueError("Skill name must be a string")
|
||||||
|
normalized = name.strip().lower()
|
||||||
|
if not _SKILL_NAME_RE.match(normalized):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid skill name {name!r}: must contain only lowercase letters, "
|
||||||
|
f"digits, hyphens, and underscores (1-64 chars)"
|
||||||
|
)
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_yaml_content(content: str) -> dict[str, Any]:
|
||||||
|
"""Validate YAML content. Returns parsed dict. Raises ``ValueError`` on invalid input.
|
||||||
|
|
||||||
|
Mirrors the validation in :mod:`agentkit.server.routes.skills` but
|
||||||
|
raises ``ValueError`` (not ``HTTPException``).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
data = yaml.safe_load(content)
|
||||||
|
except yaml.YAMLError as exc:
|
||||||
|
raise ValueError(f"Invalid YAML content: {exc}") from exc
|
||||||
|
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
raise ValueError("Skill YAML must be a mapping/dict")
|
||||||
|
|
||||||
|
if "name" not in data:
|
||||||
|
raise ValueError("Skill YAML must contain a 'name' field")
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def _is_within_directory(file_path: str, directory: str) -> bool:
|
||||||
|
"""Return ``True`` if ``file_path`` resolves inside ``directory``.
|
||||||
|
|
||||||
|
Uses :func:`os.path.realpath` to canonicalize both paths before
|
||||||
|
comparison, so symlink traversal and ``..`` segments are detected.
|
||||||
|
"""
|
||||||
|
real_file = os.path.realpath(file_path)
|
||||||
|
real_dir = os.path.realpath(directory)
|
||||||
|
return real_file.startswith(real_dir + os.sep) or real_file == real_dir
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Service
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class SkillService:
|
||||||
|
"""Admin-driven skill enable/disable + import/reload operations.
|
||||||
|
|
||||||
|
All DB-touching methods are async and take ``db_path: Path`` as the
|
||||||
|
first argument. YAML-touching methods take ``skills_dir: str`` and
|
||||||
|
optionally a :class:`SkillRegistry` instance for live reload.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Enable / disable (DB state)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def disable_skill(
|
||||||
|
self,
|
||||||
|
db_path: Path,
|
||||||
|
skill_name: str,
|
||||||
|
disabled_by: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Mark a skill as disabled in the DB.
|
||||||
|
|
||||||
|
Idempotent: if the skill is already disabled, the row is updated
|
||||||
|
(refreshing ``disabled_at`` and ``disabled_by``).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: Path to the auth SQLite DB.
|
||||||
|
skill_name: Skill name to disable (validated against the regex).
|
||||||
|
disabled_by: Optional admin user id (audit trail).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The disabled skill state dict.
|
||||||
|
"""
|
||||||
|
normalized = _validate_skill_name(skill_name)
|
||||||
|
now = _now_iso()
|
||||||
|
async with aiosqlite.connect(str(db_path)) as db:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
await db.execute(
|
||||||
|
"INSERT INTO skill_states (skill_name, is_disabled, disabled_at, disabled_by) "
|
||||||
|
"VALUES (?, 1, ?, ?) "
|
||||||
|
"ON CONFLICT(skill_name) DO UPDATE SET "
|
||||||
|
" is_disabled = excluded.is_disabled, "
|
||||||
|
" disabled_at = excluded.disabled_at, "
|
||||||
|
" disabled_by = excluded.disabled_by",
|
||||||
|
(normalized, now, disabled_by),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
cursor = await db.execute(
|
||||||
|
"SELECT * FROM skill_states WHERE skill_name = ?",
|
||||||
|
(normalized,),
|
||||||
|
)
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
assert row is not None # we just upserted it
|
||||||
|
return skill_state_row_to_dict(row)
|
||||||
|
|
||||||
|
async def enable_skill(
|
||||||
|
self,
|
||||||
|
db_path: Path,
|
||||||
|
skill_name: str,
|
||||||
|
) -> bool:
|
||||||
|
"""Remove the disabled mark for a skill.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
``True`` if a row was deleted (skill was previously disabled),
|
||||||
|
``False`` if the skill was not disabled.
|
||||||
|
"""
|
||||||
|
normalized = _validate_skill_name(skill_name)
|
||||||
|
async with aiosqlite.connect(str(db_path)) as db:
|
||||||
|
cursor = await db.execute(
|
||||||
|
"DELETE FROM skill_states WHERE skill_name = ?",
|
||||||
|
(normalized,),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
return cursor.rowcount > 0
|
||||||
|
|
||||||
|
async def is_skill_disabled(
|
||||||
|
self,
|
||||||
|
db_path: Path,
|
||||||
|
skill_name: str,
|
||||||
|
) -> bool:
|
||||||
|
"""Return ``True`` if the skill is currently disabled."""
|
||||||
|
normalized = _validate_skill_name(skill_name)
|
||||||
|
async with aiosqlite.connect(str(db_path)) as db:
|
||||||
|
cursor = await db.execute(
|
||||||
|
"SELECT is_disabled FROM skill_states WHERE skill_name = ?",
|
||||||
|
(normalized,),
|
||||||
|
)
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
if row is None:
|
||||||
|
return False
|
||||||
|
return bool(row[0])
|
||||||
|
|
||||||
|
async def list_disabled_skills(
|
||||||
|
self,
|
||||||
|
db_path: Path,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Return the list of skill names currently disabled."""
|
||||||
|
async with aiosqlite.connect(str(db_path)) as db:
|
||||||
|
cursor = await db.execute(
|
||||||
|
"SELECT skill_name FROM skill_states WHERE is_disabled = 1 ORDER BY skill_name ASC"
|
||||||
|
)
|
||||||
|
rows = await cursor.fetchall()
|
||||||
|
return [row[0] for row in rows]
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# YAML import / reload / update
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def import_skill(
|
||||||
|
self,
|
||||||
|
yaml_content: str,
|
||||||
|
skills_dir: str,
|
||||||
|
skill_registry: SkillRegistry | None = None,
|
||||||
|
tool_registry: Any = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Validate YAML, write to file, and load into the registry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
yaml_content: Raw YAML text for the skill config.
|
||||||
|
skills_dir: Target directory for the YAML file. The file is
|
||||||
|
written as ``{skills_dir}/{skill_name}.yaml``.
|
||||||
|
skill_registry: Optional :class:`SkillRegistry` to register
|
||||||
|
the loaded skill into. If ``None``, the skill is only
|
||||||
|
written to disk.
|
||||||
|
tool_registry: Optional :class:`ToolRegistry` for tool binding.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with ``name`` and ``path`` of the imported skill.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the YAML is invalid, the skill name is invalid,
|
||||||
|
or the resolved path escapes ``skills_dir``.
|
||||||
|
"""
|
||||||
|
data = _validate_yaml_content(yaml_content)
|
||||||
|
raw_name = data["name"]
|
||||||
|
skill_name = _validate_skill_name(raw_name)
|
||||||
|
|
||||||
|
os.makedirs(skills_dir, exist_ok=True)
|
||||||
|
file_path = os.path.join(skills_dir, f"{skill_name}.yaml")
|
||||||
|
|
||||||
|
# Path traversal protection: the resolved path must stay inside
|
||||||
|
# skills_dir. (skill_name is regex-validated so this should always
|
||||||
|
# pass, but we double-check defensively.)
|
||||||
|
if not _is_within_directory(file_path, skills_dir):
|
||||||
|
raise ValueError(f"Resolved skill path escapes skills_dir: {file_path}")
|
||||||
|
|
||||||
|
with open(file_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(yaml_content)
|
||||||
|
|
||||||
|
if skill_registry is not None:
|
||||||
|
try:
|
||||||
|
from agentkit.skills.loader import SkillLoader
|
||||||
|
|
||||||
|
loader = SkillLoader(
|
||||||
|
skill_registry=skill_registry,
|
||||||
|
tool_registry=tool_registry,
|
||||||
|
)
|
||||||
|
loader.load_from_file(file_path)
|
||||||
|
except Exception:
|
||||||
|
# Remove the invalid YAML file and re-raise as ValueError
|
||||||
|
# so the route layer maps it to 400.
|
||||||
|
try:
|
||||||
|
os.remove(file_path)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
raise ValueError("Skill YAML written but registration failed")
|
||||||
|
|
||||||
|
return {"name": skill_name, "path": file_path}
|
||||||
|
|
||||||
|
async def reload_skill(
|
||||||
|
self,
|
||||||
|
skill_name: str,
|
||||||
|
skill_registry: SkillRegistry,
|
||||||
|
skills_dir: str,
|
||||||
|
tool_registry: Any = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Unregister and reload a skill from its YAML file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skill_name: Skill name to reload (validated against the regex).
|
||||||
|
skill_registry: The live :class:`SkillRegistry` instance.
|
||||||
|
skills_dir: Directory containing the YAML file.
|
||||||
|
tool_registry: Optional :class:`ToolRegistry` for tool binding.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with ``name``, ``path``, and ``status`` of the reload.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the skill name is invalid or the YAML file
|
||||||
|
cannot be found.
|
||||||
|
"""
|
||||||
|
normalized = _validate_skill_name(skill_name)
|
||||||
|
file_path = os.path.join(skills_dir, f"{normalized}.yaml")
|
||||||
|
if not os.path.isfile(file_path):
|
||||||
|
raise ValueError(f"Skill {normalized!r} not found: no YAML file at {file_path}")
|
||||||
|
|
||||||
|
if not _is_within_directory(file_path, skills_dir):
|
||||||
|
raise ValueError(f"Resolved skill path escapes skills_dir: {file_path}")
|
||||||
|
|
||||||
|
# Unregister the existing skill (if any) before reloading so the
|
||||||
|
# registry picks up the new config cleanly.
|
||||||
|
try:
|
||||||
|
skill_registry.unregister(normalized)
|
||||||
|
except Exception: # noqa: BLE001 — unregister is best-effort
|
||||||
|
logger.debug("Skill %r was not registered before reload", normalized)
|
||||||
|
|
||||||
|
from agentkit.skills.loader import SkillLoader
|
||||||
|
|
||||||
|
loader = SkillLoader(
|
||||||
|
skill_registry=skill_registry,
|
||||||
|
tool_registry=tool_registry,
|
||||||
|
)
|
||||||
|
loader.load_from_file(file_path)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"name": normalized,
|
||||||
|
"path": file_path,
|
||||||
|
"status": "reloaded",
|
||||||
|
}
|
||||||
|
|
||||||
|
async def update_skill_config(
|
||||||
|
self,
|
||||||
|
skill_name: str,
|
||||||
|
config_patch: dict[str, Any],
|
||||||
|
skills_dir: str,
|
||||||
|
skill_registry: SkillRegistry,
|
||||||
|
tool_registry: Any = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Update a skill's YAML config in-place and reload it.
|
||||||
|
|
||||||
|
Performs a shallow merge: top-level keys in ``config_patch``
|
||||||
|
overwrite the existing YAML values. Nested dicts are replaced
|
||||||
|
wholesale (not deep-merged) to keep the semantics predictable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skill_name: Skill name to update (validated against the regex).
|
||||||
|
config_patch: Dict of top-level YAML keys to update.
|
||||||
|
skills_dir: Directory containing the YAML file.
|
||||||
|
skill_registry: The live :class:`SkillRegistry` instance.
|
||||||
|
tool_registry: Optional :class:`ToolRegistry` for tool binding.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with ``name``, ``path``, and ``status`` of the update.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the skill name is invalid, the YAML file
|
||||||
|
cannot be found, or the patched config is invalid.
|
||||||
|
"""
|
||||||
|
normalized = _validate_skill_name(skill_name)
|
||||||
|
file_path = os.path.join(skills_dir, f"{normalized}.yaml")
|
||||||
|
if not os.path.isfile(file_path):
|
||||||
|
raise ValueError(f"Skill {normalized!r} not found: no YAML file at {file_path}")
|
||||||
|
|
||||||
|
if not _is_within_directory(file_path, skills_dir):
|
||||||
|
raise ValueError(f"Resolved skill path escapes skills_dir: {file_path}")
|
||||||
|
|
||||||
|
# Read existing YAML, apply patch, validate, write back.
|
||||||
|
with open(file_path, encoding="utf-8") as f:
|
||||||
|
existing = yaml.safe_load(f) or {}
|
||||||
|
|
||||||
|
if not isinstance(existing, dict):
|
||||||
|
raise ValueError(f"Existing skill YAML at {file_path} is not a mapping")
|
||||||
|
|
||||||
|
# Shallow merge: top-level keys in config_patch overwrite.
|
||||||
|
# The 'name' field is preserved (cannot rename via patch).
|
||||||
|
patched = {**existing, **config_patch, "name": existing.get("name", normalized)}
|
||||||
|
|
||||||
|
# Validate the patched config by round-tripping through YAML.
|
||||||
|
try:
|
||||||
|
patched_yaml = yaml.safe_dump(patched, default_flow_style=False, allow_unicode=True)
|
||||||
|
except yaml.YAMLError as exc:
|
||||||
|
raise ValueError(f"Failed to serialize patched config: {exc}") from exc
|
||||||
|
|
||||||
|
_validate_yaml_content(patched_yaml)
|
||||||
|
|
||||||
|
with open(file_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(patched_yaml)
|
||||||
|
|
||||||
|
# Reload the skill into the registry.
|
||||||
|
try:
|
||||||
|
skill_registry.unregister(normalized)
|
||||||
|
except Exception: # noqa: BLE001 — unregister is best-effort
|
||||||
|
logger.debug("Skill %r was not registered before update", normalized)
|
||||||
|
|
||||||
|
from agentkit.skills.loader import SkillLoader
|
||||||
|
|
||||||
|
loader = SkillLoader(
|
||||||
|
skill_registry=skill_registry,
|
||||||
|
tool_registry=tool_registry,
|
||||||
|
)
|
||||||
|
loader.load_from_file(file_path)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"name": normalized,
|
||||||
|
"path": file_path,
|
||||||
|
"status": "updated",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Module-level singleton (overridable in tests via set_skill_service)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
_skill_service: SkillService | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_skill_service() -> SkillService:
|
||||||
|
"""Return the process-wide :class:`SkillService` (lazy singleton)."""
|
||||||
|
global _skill_service
|
||||||
|
if _skill_service is None:
|
||||||
|
_skill_service = SkillService()
|
||||||
|
return _skill_service
|
||||||
|
|
||||||
|
|
||||||
|
def set_skill_service(service: SkillService | None) -> None:
|
||||||
|
"""Inject a custom :class:`SkillService` (used by tests)."""
|
||||||
|
global _skill_service
|
||||||
|
_skill_service = service
|
||||||
|
|
@ -351,6 +351,24 @@ class DepartmentQuotaModel(Base):
|
||||||
updated_at: Mapped[str] = mapped_column(String(64), nullable=False, default=_now_iso)
|
updated_at: Mapped[str] = mapped_column(String(64), nullable=False, default=_now_iso)
|
||||||
|
|
||||||
|
|
||||||
|
class SkillStateModel(Base):
|
||||||
|
"""Skill enable/disable state (V4 — Admin Console U6).
|
||||||
|
|
||||||
|
Each row records whether a named skill has been disabled by an
|
||||||
|
admin via the ``/admin/skills/{name}/disable`` endpoint. Skills
|
||||||
|
with no row here are considered enabled (the default). ``skill_name``
|
||||||
|
references the skill registry identifier (not a DB FK — skills are
|
||||||
|
defined in YAML configs).
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "skill_states"
|
||||||
|
|
||||||
|
skill_name: Mapped[str] = mapped_column(String(128), primary_key=True)
|
||||||
|
is_disabled: Mapped[bool] = mapped_column(default=True, nullable=False)
|
||||||
|
disabled_at: Mapped[str] = mapped_column(String(64), nullable=False, default=_now_iso)
|
||||||
|
disabled_by: Mapped[str | None] = mapped_column(String(36), nullable=True)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Schema DDL (kept in sync with the models above for aiosqlite bootstrap)
|
# Schema DDL (kept in sync with the models above for aiosqlite bootstrap)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -565,6 +583,17 @@ CREATE TABLE IF NOT EXISTS department_quotas (
|
||||||
);
|
);
|
||||||
CREATE INDEX IF NOT EXISTS idx_department_quotas_department_id
|
CREATE INDEX IF NOT EXISTS idx_department_quotas_department_id
|
||||||
ON department_quotas(department_id);
|
ON department_quotas(department_id);
|
||||||
|
|
||||||
|
-- V4: skill_states records admin-disabled skills (U6 — Admin Console).
|
||||||
|
-- A skill with no row here is considered enabled (the default). Only
|
||||||
|
-- disabled skills have a row, with is_disabled=1. disabled_by records
|
||||||
|
-- the admin user id who disabled the skill (audit trail).
|
||||||
|
CREATE TABLE IF NOT EXISTS skill_states (
|
||||||
|
skill_name TEXT PRIMARY KEY,
|
||||||
|
is_disabled INTEGER NOT NULL DEFAULT 1,
|
||||||
|
disabled_at TEXT NOT NULL,
|
||||||
|
disabled_by TEXT
|
||||||
|
);
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -580,7 +609,10 @@ CREATE INDEX IF NOT EXISTS idx_department_quotas_department_id
|
||||||
# V3 (2026-06-21, Admin Console): added departments, user_departments,
|
# V3 (2026-06-21, Admin Console): added departments, user_departments,
|
||||||
# department_skill_bindings, department_kb_bindings, department_quotas.
|
# department_skill_bindings, department_kb_bindings, department_quotas.
|
||||||
# No backfill needed — all new tables are additive.
|
# No backfill needed — all new tables are additive.
|
||||||
_SCHEMA_VERSION = 3
|
#
|
||||||
|
# V4 (2026-06-21, Admin Console U6): added skill_states table for
|
||||||
|
# admin-driven skill enable/disable. No backfill needed — additive.
|
||||||
|
_SCHEMA_VERSION = 4
|
||||||
|
|
||||||
_META_SCHEMA_VERSION_KEY = "schema_version"
|
_META_SCHEMA_VERSION_KEY = "schema_version"
|
||||||
|
|
||||||
|
|
@ -803,3 +835,18 @@ def user_department_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> di
|
||||||
"department_id": row["department_id"],
|
"department_id": row["department_id"],
|
||||||
"created_at": row["created_at"],
|
"created_at": row["created_at"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def skill_state_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> dict[str, Any]:
|
||||||
|
"""Convert a ``skill_states`` row into a JSON-safe dict.
|
||||||
|
|
||||||
|
The ``is_disabled`` field is normalized to a Python ``bool`` (the DB
|
||||||
|
stores 0/1). ``disabled_by`` is ``None`` when the disabling admin
|
||||||
|
is not recorded (e.g. legacy rows or system-initiated disables).
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"skill_name": row["skill_name"],
|
||||||
|
"is_disabled": bool(row["is_disabled"]),
|
||||||
|
"disabled_at": row["disabled_at"],
|
||||||
|
"disabled_by": row["disabled_by"],
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -23,11 +23,13 @@ from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
from agentkit.server.admin.department_service import get_department_service
|
from agentkit.server.admin.department_service import get_department_service
|
||||||
|
from agentkit.server.admin.kb_service import get_kb_service
|
||||||
from agentkit.server.admin.llm_config_service import (
|
from agentkit.server.admin.llm_config_service import (
|
||||||
LlmConfigService,
|
LlmConfigService,
|
||||||
get_llm_config_service,
|
get_llm_config_service,
|
||||||
)
|
)
|
||||||
from agentkit.server.admin.quota_service import get_quota_service
|
from agentkit.server.admin.quota_service import get_quota_service
|
||||||
|
from agentkit.server.admin.skill_service import get_skill_service
|
||||||
from agentkit.server.admin.user_service import get_user_service
|
from agentkit.server.admin.user_service import get_user_service
|
||||||
from agentkit.server.auth.dependencies import require_authenticated
|
from agentkit.server.auth.dependencies import require_authenticated
|
||||||
from agentkit.server.auth.models import DEFAULT_AUTH_DB_PATH, init_auth_db
|
from agentkit.server.auth.models import DEFAULT_AUTH_DB_PATH, init_auth_db
|
||||||
|
|
@ -854,3 +856,285 @@ async def delete_department_quota(
|
||||||
if not deleted:
|
if not deleted:
|
||||||
raise HTTPException(status_code=404, detail="Quota not found")
|
raise HTTPException(status_code=404, detail="Quota not found")
|
||||||
return {"deleted": True}
|
return {"deleted": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Skill management endpoints (U6) — enable/disable + import/reload
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _get_skills_dir(request: Request) -> str:
|
||||||
|
"""Resolve the skills directory from ``app.state.server_config``.
|
||||||
|
|
||||||
|
Mirrors the logic in :func:`agentkit.server.routes.skills._get_skills_dir`
|
||||||
|
so admin endpoints install/reload skills into the same directory as
|
||||||
|
the existing ``/skills/install`` route.
|
||||||
|
"""
|
||||||
|
server_config = getattr(request.app.state, "server_config", None)
|
||||||
|
if server_config and getattr(server_config, "skill_paths", None):
|
||||||
|
first_path = Path(server_config.skill_paths[0])
|
||||||
|
if first_path.is_dir():
|
||||||
|
return str(first_path)
|
||||||
|
# Fallback: configs/skills/ relative to cwd (matches routes.skills).
|
||||||
|
import os
|
||||||
|
|
||||||
|
return os.path.join(os.getcwd(), "configs", "skills")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_skill_registry(request: Request) -> Any:
|
||||||
|
"""Return the live :class:`SkillRegistry` from ``app.state``.
|
||||||
|
|
||||||
|
Raises HTTPException(500) if the registry is missing — admin skill
|
||||||
|
endpoints cannot function without it.
|
||||||
|
"""
|
||||||
|
registry = getattr(request.app.state, "skill_registry", None)
|
||||||
|
if registry is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail="Skill registry not initialized on app.state",
|
||||||
|
)
|
||||||
|
return registry
|
||||||
|
|
||||||
|
|
||||||
|
class SkillImportRequest(BaseModel):
|
||||||
|
"""Body for ``POST /admin/skills/import``."""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
yaml_content: str
|
||||||
|
|
||||||
|
|
||||||
|
class SkillUpdateRequest(BaseModel):
|
||||||
|
"""Body for ``PATCH /admin/skills/{name}``."""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
config: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.post("/skills/{name}/enable")
|
||||||
|
async def enable_skill(
|
||||||
|
name: str,
|
||||||
|
request: Request,
|
||||||
|
admin: dict[str, Any] = Depends(_require_admin),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Enable a previously-disabled skill.
|
||||||
|
|
||||||
|
Returns 200 ``{enabled: true}`` on success. Idempotent — returns
|
||||||
|
200 even if the skill was not disabled.
|
||||||
|
"""
|
||||||
|
db_path = await _ensure_db(request)
|
||||||
|
svc = get_skill_service()
|
||||||
|
try:
|
||||||
|
enabled = await svc.enable_skill(db_path, name)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||||
|
return {"enabled": enabled, "skill_name": name}
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.post("/skills/{name}/disable")
|
||||||
|
async def disable_skill(
|
||||||
|
name: str,
|
||||||
|
request: Request,
|
||||||
|
admin: dict[str, Any] = Depends(_require_admin),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Disable a skill (hides it from ``GET /skills``).
|
||||||
|
|
||||||
|
Returns 200 with the disabled skill state dict.
|
||||||
|
"""
|
||||||
|
db_path = await _ensure_db(request)
|
||||||
|
svc = get_skill_service()
|
||||||
|
try:
|
||||||
|
return await svc.disable_skill(db_path, name, disabled_by=admin.get("user_id"))
|
||||||
|
except ValueError as exc:
|
||||||
|
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.patch("/skills/{name}")
|
||||||
|
async def update_skill(
|
||||||
|
name: str,
|
||||||
|
payload: SkillUpdateRequest,
|
||||||
|
request: Request,
|
||||||
|
admin: dict[str, Any] = Depends(_require_admin),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Update a skill's YAML config in-place and reload it.
|
||||||
|
|
||||||
|
Returns 200 with the updated skill info. Returns 404 if the skill
|
||||||
|
YAML file does not exist, 400 if the patched config is invalid.
|
||||||
|
"""
|
||||||
|
skills_dir = _get_skills_dir(request)
|
||||||
|
registry = _get_skill_registry(request)
|
||||||
|
svc = get_skill_service()
|
||||||
|
try:
|
||||||
|
return await svc.update_skill_config(
|
||||||
|
name,
|
||||||
|
payload.config,
|
||||||
|
skills_dir,
|
||||||
|
registry,
|
||||||
|
)
|
||||||
|
except ValueError as exc:
|
||||||
|
msg = str(exc)
|
||||||
|
if "not found" in msg:
|
||||||
|
raise HTTPException(status_code=404, detail=msg) from exc
|
||||||
|
raise HTTPException(status_code=400, detail=msg) from exc
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.post("/skills/import")
|
||||||
|
async def import_skill(
|
||||||
|
payload: SkillImportRequest,
|
||||||
|
request: Request,
|
||||||
|
admin: dict[str, Any] = Depends(_require_admin),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Import a skill from YAML content.
|
||||||
|
|
||||||
|
Writes the YAML to ``{skills_dir}/{name}.yaml`` and registers it
|
||||||
|
in the live :class:`SkillRegistry`. Returns 200 with the imported
|
||||||
|
skill info. Returns 400 if the YAML is invalid or the skill name
|
||||||
|
is invalid.
|
||||||
|
"""
|
||||||
|
skills_dir = _get_skills_dir(request)
|
||||||
|
registry = _get_skill_registry(request)
|
||||||
|
svc = get_skill_service()
|
||||||
|
try:
|
||||||
|
return await svc.import_skill(
|
||||||
|
payload.yaml_content,
|
||||||
|
skills_dir,
|
||||||
|
skill_registry=registry,
|
||||||
|
)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.post("/skills/{name}/reload")
|
||||||
|
async def reload_skill(
|
||||||
|
name: str,
|
||||||
|
request: Request,
|
||||||
|
admin: dict[str, Any] = Depends(_require_admin),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Reload a skill from its YAML file.
|
||||||
|
|
||||||
|
Returns 200 with the reloaded skill info. Returns 404 if the YAML
|
||||||
|
file does not exist, 400 if the skill name is invalid.
|
||||||
|
"""
|
||||||
|
skills_dir = _get_skills_dir(request)
|
||||||
|
registry = _get_skill_registry(request)
|
||||||
|
svc = get_skill_service()
|
||||||
|
try:
|
||||||
|
return await svc.reload_skill(name, registry, skills_dir)
|
||||||
|
except ValueError as exc:
|
||||||
|
msg = str(exc)
|
||||||
|
if "not found" in msg:
|
||||||
|
raise HTTPException(status_code=404, detail=msg) from exc
|
||||||
|
raise HTTPException(status_code=400, detail=msg) from exc
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# KB management endpoints (U6) — documents + sources
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class KbDocumentUploadRequest(BaseModel):
|
||||||
|
"""Body for ``POST /admin/kb/documents``."""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
filename: str
|
||||||
|
content: str
|
||||||
|
source_id: str = ""
|
||||||
|
department_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.get("/kb/documents")
|
||||||
|
async def list_kb_documents(
|
||||||
|
request: Request,
|
||||||
|
source_id: str | None = None,
|
||||||
|
department_id: str | None = None,
|
||||||
|
admin: dict[str, Any] = Depends(_require_admin),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""List KB documents (admin sees all).
|
||||||
|
|
||||||
|
Query params: ``source_id`` (optional filter), ``department_id``
|
||||||
|
(optional filter — when set, only documents bound to that
|
||||||
|
department or global documents are returned).
|
||||||
|
"""
|
||||||
|
svc = get_kb_service()
|
||||||
|
# Admin sees everything. If department_id is provided as a query
|
||||||
|
# filter, narrow to that department + global docs.
|
||||||
|
if department_id is not None:
|
||||||
|
dept_ids = [department_id]
|
||||||
|
else:
|
||||||
|
dept_ids = None # admin: no filtering
|
||||||
|
documents = svc.list_documents(department_ids=dept_ids, source_id=source_id)
|
||||||
|
return {"documents": documents}
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.post("/kb/documents", status_code=201)
|
||||||
|
async def upload_kb_document(
|
||||||
|
payload: KbDocumentUploadRequest,
|
||||||
|
request: Request,
|
||||||
|
admin: dict[str, Any] = Depends(_require_admin),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Upload a KB document with optional department binding.
|
||||||
|
|
||||||
|
Returns 201 with the uploaded document dict.
|
||||||
|
"""
|
||||||
|
svc = get_kb_service()
|
||||||
|
return svc.upload_document(
|
||||||
|
filename=payload.filename,
|
||||||
|
content=payload.content.encode("utf-8"),
|
||||||
|
source_id=payload.source_id,
|
||||||
|
department_id=payload.department_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.delete("/kb/documents/{document_id}")
|
||||||
|
async def delete_kb_document(
|
||||||
|
document_id: str,
|
||||||
|
request: Request,
|
||||||
|
admin: dict[str, Any] = Depends(_require_admin),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Delete a KB document by id.
|
||||||
|
|
||||||
|
Returns 200 ``{deleted: true}`` on success, 404 if not found.
|
||||||
|
"""
|
||||||
|
svc = get_kb_service()
|
||||||
|
deleted = svc.delete_document(document_id)
|
||||||
|
if not deleted:
|
||||||
|
raise HTTPException(status_code=404, detail="Document not found")
|
||||||
|
return {"deleted": True}
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.post("/kb/sources/{source_id}/sync")
|
||||||
|
async def sync_kb_source(
|
||||||
|
source_id: str,
|
||||||
|
request: Request,
|
||||||
|
admin: dict[str, Any] = Depends(_require_admin),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Trigger a sync for a KB source.
|
||||||
|
|
||||||
|
Returns 200 with the sync status. Returns 404 if the source does
|
||||||
|
not exist.
|
||||||
|
"""
|
||||||
|
svc = get_kb_service()
|
||||||
|
try:
|
||||||
|
return svc.sync_source(source_id)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise HTTPException(status_code=404, detail=str(exc)) from exc
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.post("/kb/sources/{source_id}/rebuild")
|
||||||
|
async def rebuild_kb_source(
|
||||||
|
source_id: str,
|
||||||
|
request: Request,
|
||||||
|
admin: dict[str, Any] = Depends(_require_admin),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Rebuild the index for a KB source.
|
||||||
|
|
||||||
|
Returns 200 with the rebuild status. Returns 404 if the source
|
||||||
|
does not exist.
|
||||||
|
"""
|
||||||
|
svc = get_kb_service()
|
||||||
|
try:
|
||||||
|
return svc.rebuild_index(source_id)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise HTTPException(status_code=404, detail=str(exc)) from exc
|
||||||
|
|
|
||||||
|
|
@ -73,6 +73,7 @@ class KnowledgeSource:
|
||||||
status: str = "active"
|
status: str = "active"
|
||||||
document_count: int = 0
|
document_count: int = 0
|
||||||
last_synced: str | None = None
|
last_synced: str | None = None
|
||||||
|
department_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -83,6 +84,7 @@ class UploadedDocument:
|
||||||
chunks: int
|
chunks: int
|
||||||
status: str
|
status: str
|
||||||
created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||||
|
department_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeSourceStore:
|
class KnowledgeSourceStore:
|
||||||
|
|
@ -155,6 +157,23 @@ class KnowledgeSourceStore:
|
||||||
_source_store = KnowledgeSourceStore()
|
_source_store = KnowledgeSourceStore()
|
||||||
|
|
||||||
|
|
||||||
|
def get_source_store() -> KnowledgeSourceStore:
|
||||||
|
"""Return the process-wide :class:`KnowledgeSourceStore` singleton.
|
||||||
|
|
||||||
|
Exposed as a function (rather than importing ``_source_store``
|
||||||
|
directly) so tests can swap it out via monkeypatch, and so the
|
||||||
|
:class:`KbService` wrapper can access the store without reaching
|
||||||
|
into a private module attribute.
|
||||||
|
"""
|
||||||
|
return _source_store
|
||||||
|
|
||||||
|
|
||||||
|
def set_source_store(store: KnowledgeSourceStore) -> None:
|
||||||
|
"""Inject a custom :class:`KnowledgeSourceStore` (used by tests)."""
|
||||||
|
global _source_store
|
||||||
|
_source_store = store
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Request / Response models
|
# Request / Response models
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
|
||||||
|
|
@ -54,9 +54,7 @@ def _skill_to_info(skill: Any) -> dict[str, Any]:
|
||||||
if hasattr(skill, "config") and hasattr(skill.config, "capabilities"):
|
if hasattr(skill, "config") and hasattr(skill.config, "capabilities"):
|
||||||
caps = skill.config.capabilities
|
caps = skill.config.capabilities
|
||||||
if isinstance(caps, list):
|
if isinstance(caps, list):
|
||||||
capabilities = [
|
capabilities = [c.tag if hasattr(c, "tag") else str(c) for c in caps]
|
||||||
c.tag if hasattr(c, "tag") else str(c) for c in caps
|
|
||||||
]
|
|
||||||
elif isinstance(caps, dict):
|
elif isinstance(caps, dict):
|
||||||
capabilities = list(caps.keys())
|
capabilities = list(caps.keys())
|
||||||
|
|
||||||
|
|
@ -160,7 +158,7 @@ async def check_skill_health(skill_name: str, req: Request):
|
||||||
"""Check the health of a specific skill."""
|
"""Check the health of a specific skill."""
|
||||||
skill_registry = req.app.state.skill_registry
|
skill_registry = req.app.state.skill_registry
|
||||||
try:
|
try:
|
||||||
skill = skill_registry.get(skill_name)
|
skill_registry.get(skill_name)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
|
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
|
||||||
|
|
||||||
|
|
@ -213,17 +211,43 @@ async def list_capabilities(req: Request):
|
||||||
|
|
||||||
@router.post("/skill-management/skills/{skill_name}/reload")
|
@router.post("/skill-management/skills/{skill_name}/reload")
|
||||||
async def reload_skill(skill_name: str, req: Request):
|
async def reload_skill(skill_name: str, req: Request):
|
||||||
"""Reload a skill configuration."""
|
"""Reload a skill configuration from its YAML file."""
|
||||||
skill_registry = req.app.state.skill_registry
|
skill_registry = req.app.state.skill_registry
|
||||||
|
# Verify the skill is currently registered (404 if not).
|
||||||
try:
|
try:
|
||||||
skill = skill_registry.get(skill_name)
|
skill_registry.get(skill_name)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
|
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
|
||||||
|
|
||||||
# In a full implementation, this would reload the skill from its config source
|
# Resolve the skills directory (mirrors routes.skills._get_skills_dir).
|
||||||
# For now, just return success
|
import os
|
||||||
|
|
||||||
|
skills_dir: str
|
||||||
|
server_config = getattr(req.app.state, "server_config", None)
|
||||||
|
if server_config and getattr(server_config, "skill_paths", None):
|
||||||
|
from pathlib import Path as _P
|
||||||
|
|
||||||
|
first_path = _P(server_config.skill_paths[0])
|
||||||
|
if first_path.is_dir():
|
||||||
|
skills_dir = str(first_path)
|
||||||
|
else:
|
||||||
|
skills_dir = os.path.join(os.getcwd(), "configs", "skills")
|
||||||
|
else:
|
||||||
|
skills_dir = os.path.join(os.getcwd(), "configs", "skills")
|
||||||
|
|
||||||
|
from agentkit.server.admin.skill_service import get_skill_service
|
||||||
|
|
||||||
|
svc = get_skill_service()
|
||||||
|
try:
|
||||||
|
result = await svc.reload_skill(skill_name, skill_registry, skills_dir)
|
||||||
|
except ValueError as exc:
|
||||||
|
msg = str(exc)
|
||||||
|
if "not found" in msg:
|
||||||
|
raise HTTPException(status_code=404, detail=msg) from exc
|
||||||
|
raise HTTPException(status_code=400, detail=msg) from exc
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"skill_name": skill_name,
|
"skill_name": result["name"],
|
||||||
"status": "reloaded",
|
"status": result["status"],
|
||||||
"message": f"技能 '{skill_name}' 已重新加载",
|
"message": f"技能 '{skill_name}' 已重新加载",
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,7 @@ def _get_skills_dir(req: Request) -> str:
|
||||||
if server_config and server_config.skill_paths:
|
if server_config and server_config.skill_paths:
|
||||||
# Use the first configured skill path as the install target
|
# Use the first configured skill path as the install target
|
||||||
from pathlib import Path as _P
|
from pathlib import Path as _P
|
||||||
|
|
||||||
first_path = _P(server_config.skill_paths[0])
|
first_path = _P(server_config.skill_paths[0])
|
||||||
if first_path.is_dir():
|
if first_path.is_dir():
|
||||||
return str(first_path)
|
return str(first_path)
|
||||||
|
|
@ -59,12 +60,16 @@ def _get_skills_dir(req: Request) -> str:
|
||||||
def _validate_source_url(source: str) -> None:
|
def _validate_source_url(source: str) -> None:
|
||||||
"""Validate that a source URL points to an allowed domain (SSRF mitigation)."""
|
"""Validate that a source URL points to an allowed domain (SSRF mitigation)."""
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
parsed = urlparse(source)
|
parsed = urlparse(source)
|
||||||
if parsed.scheme not in ("https", "http"):
|
if parsed.scheme not in ("https", "http"):
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid source URL scheme: only http/https allowed")
|
raise HTTPException(
|
||||||
|
status_code=400, detail="Invalid source URL scheme: only http/https allowed"
|
||||||
|
)
|
||||||
# Block private/internal IPs by checking hostname
|
# Block private/internal IPs by checking hostname
|
||||||
import ipaddress
|
import ipaddress
|
||||||
import socket
|
import socket
|
||||||
|
|
||||||
hostname = parsed.hostname
|
hostname = parsed.hostname
|
||||||
if hostname:
|
if hostname:
|
||||||
try:
|
try:
|
||||||
|
|
@ -82,12 +87,15 @@ def _validate_source_url(source: str) -> None:
|
||||||
# Check domain allowlist for source URLs
|
# Check domain allowlist for source URLs
|
||||||
if hostname and hostname not in _ALLOWED_DOWNLOAD_DOMAINS:
|
if hostname and hostname not in _ALLOWED_DOWNLOAD_DOMAINS:
|
||||||
# Allow but log a warning for non-allowlisted domains
|
# Allow but log a warning for non-allowlisted domains
|
||||||
logger.warning(f"Source URL domain '{hostname}' is not in the allowlist: {_ALLOWED_DOWNLOAD_DOMAINS}")
|
logger.warning(
|
||||||
|
f"Source URL domain '{hostname}' is not in the allowlist: {_ALLOWED_DOWNLOAD_DOMAINS}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _validate_yaml_content(content: str) -> dict:
|
def _validate_yaml_content(content: str) -> dict:
|
||||||
"""Validate YAML content before writing to disk. Returns parsed dict."""
|
"""Validate YAML content before writing to disk. Returns parsed dict."""
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = yaml.safe_load(content)
|
data = yaml.safe_load(content)
|
||||||
except yaml.YAMLError as e:
|
except yaml.YAMLError as e:
|
||||||
|
|
@ -156,16 +164,37 @@ async def list_skills(
|
||||||
Admin users (``role == "admin"``) bypass filtering and see all
|
Admin users (``role == "admin"``) bypass filtering and see all
|
||||||
registered skills. Unauthenticated callers (API-key clients) see
|
registered skills. Unauthenticated callers (API-key clients) see
|
||||||
only global skills.
|
only global skills.
|
||||||
|
|
||||||
|
Disabled-skill filtering (U6): skills marked as disabled by an
|
||||||
|
admin via ``POST /admin/skills/{name}/disable`` are excluded from
|
||||||
|
the response for ALL callers (admins included). This keeps the
|
||||||
|
public skill list in sync with the admin enable/disable state.
|
||||||
"""
|
"""
|
||||||
skill_registry = req.app.state.skill_registry
|
skill_registry = req.app.state.skill_registry
|
||||||
skills = skill_registry.list_skills()
|
skills = skill_registry.list_skills()
|
||||||
|
|
||||||
|
# U6: filter out admin-disabled skills (applies to all callers).
|
||||||
|
db_path = _resolve_db_path(req)
|
||||||
|
if db_path.exists():
|
||||||
|
try:
|
||||||
|
from agentkit.server.admin.skill_service import get_skill_service
|
||||||
|
|
||||||
|
svc = get_skill_service()
|
||||||
|
disabled_names = set(await svc.list_disabled_skills(db_path))
|
||||||
|
except Exception: # noqa: BLE001 — never block listing on DB errors
|
||||||
|
logger.exception("Failed to load disabled skills list — skipping filter")
|
||||||
|
disabled_names = set()
|
||||||
|
else:
|
||||||
|
disabled_names = set()
|
||||||
|
|
||||||
|
if disabled_names:
|
||||||
|
skills = [s for s in skills if s.name not in disabled_names]
|
||||||
|
|
||||||
# Admins bypass department filtering.
|
# Admins bypass department filtering.
|
||||||
if not dept_ctx.should_filter:
|
if not dept_ctx.should_filter:
|
||||||
return _serialize_skills(skills)
|
return _serialize_skills(skills)
|
||||||
|
|
||||||
# Non-admin: filter by department bindings.
|
# Non-admin: filter by department bindings.
|
||||||
db_path = _resolve_db_path(req)
|
|
||||||
all_names = [s.name for s in skills]
|
all_names = [s.name for s in skills]
|
||||||
try:
|
try:
|
||||||
visible_names = await filter_skills_by_department(
|
visible_names = await filter_skills_by_department(
|
||||||
|
|
@ -218,15 +247,13 @@ async def mention_suggest(q: str = "", req: Request = None):
|
||||||
|
|
||||||
if query:
|
if query:
|
||||||
skills = [
|
skills = [
|
||||||
s for s in skills
|
s
|
||||||
|
for s in skills
|
||||||
if query in s.name.lower()
|
if query in s.name.lower()
|
||||||
or (s.config.description and query in s.config.description.lower())
|
or (s.config.description and query in s.config.description.lower())
|
||||||
]
|
]
|
||||||
|
|
||||||
return [
|
return [{"name": s.name, "description": s.config.description or ""} for s in skills[:8]]
|
||||||
{"name": s.name, "description": s.config.description or ""}
|
|
||||||
for s in skills[:8]
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/skills/install")
|
@router.post("/skills/install")
|
||||||
|
|
@ -246,7 +273,9 @@ async def install_skill(request: InstallSkillRequest, req: Request):
|
||||||
if source and source.startswith("http"):
|
if source and source.startswith("http"):
|
||||||
_validate_source_url(source)
|
_validate_source_url(source)
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=30, follow_redirects=True, max_redirects=3) as client:
|
async with httpx.AsyncClient(
|
||||||
|
timeout=30, follow_redirects=True, max_redirects=3
|
||||||
|
) as client:
|
||||||
resp = await client.get(source)
|
resp = await client.get(source)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
yaml_content = resp.text
|
yaml_content = resp.text
|
||||||
|
|
@ -260,7 +289,9 @@ async def install_skill(request: InstallSkillRequest, req: Request):
|
||||||
# Verify the path is within the skills directory
|
# Verify the path is within the skills directory
|
||||||
skills_dir_base = _get_skills_dir(req)
|
skills_dir_base = _get_skills_dir(req)
|
||||||
if not os.path.realpath(local_path).startswith(os.path.realpath(skills_dir_base)):
|
if not os.path.realpath(local_path).startswith(os.path.realpath(skills_dir_base)):
|
||||||
raise HTTPException(status_code=400, detail="Local file path must be within the skills directory")
|
raise HTTPException(
|
||||||
|
status_code=400, detail="Local file path must be within the skills directory"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
with open(local_path, encoding="utf-8") as f:
|
with open(local_path, encoding="utf-8") as f:
|
||||||
yaml_content = f.read()
|
yaml_content = f.read()
|
||||||
|
|
@ -290,12 +321,17 @@ async def install_skill(request: InstallSkillRequest, req: Request):
|
||||||
# Fallback: try a simpler search
|
# Fallback: try a simpler search
|
||||||
search_query2 = f"{skill_name} skill"
|
search_query2 = f"{skill_name} skill"
|
||||||
encoded_query2 = urllib.parse.quote(search_query2)
|
encoded_query2 = urllib.parse.quote(search_query2)
|
||||||
github_api2 = f"https://api.github.com/search/code?q={encoded_query2}+extension:yaml&per_page=5"
|
github_api2 = (
|
||||||
|
f"https://api.github.com/search/code?q={encoded_query2}+extension:yaml&per_page=5"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=15) as client:
|
async with httpx.AsyncClient(timeout=15) as client:
|
||||||
gh_resp2 = await client.get(
|
gh_resp2 = await client.get(
|
||||||
github_api2,
|
github_api2,
|
||||||
headers={"Accept": "application/vnd.github.v3+json", "User-Agent": "agentkit"},
|
headers={
|
||||||
|
"Accept": "application/vnd.github.v3+json",
|
||||||
|
"User-Agent": "agentkit",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
items = gh_resp2.json().get("items", [])
|
items = gh_resp2.json().get("items", [])
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -310,13 +346,19 @@ async def install_skill(request: InstallSkillRequest, req: Request):
|
||||||
if raw_url:
|
if raw_url:
|
||||||
# Validate the URL is from github.com before transforming
|
# Validate the URL is from github.com before transforming
|
||||||
if not raw_url.startswith("https://github.com/"):
|
if not raw_url.startswith("https://github.com/"):
|
||||||
raise HTTPException(status_code=400, detail="Search result URL is not from github.com")
|
raise HTTPException(
|
||||||
raw_url = raw_url.replace("github.com", "raw.githubusercontent.com").replace("/blob/", "/")
|
status_code=400, detail="Search result URL is not from github.com"
|
||||||
|
)
|
||||||
|
raw_url = raw_url.replace("github.com", "raw.githubusercontent.com").replace(
|
||||||
|
"/blob/", "/"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=404, detail="Could not construct download URL")
|
raise HTTPException(status_code=404, detail="Could not construct download URL")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=30, follow_redirects=True, max_redirects=3) as client:
|
async with httpx.AsyncClient(
|
||||||
|
timeout=30, follow_redirects=True, max_redirects=3
|
||||||
|
) as client:
|
||||||
resp = await client.get(raw_url)
|
resp = await client.get(raw_url)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
yaml_content = resp.text
|
yaml_content = resp.text
|
||||||
|
|
@ -342,6 +384,7 @@ async def install_skill(request: InstallSkillRequest, req: Request):
|
||||||
registration_ok = False
|
registration_ok = False
|
||||||
try:
|
try:
|
||||||
from agentkit.skills.loader import SkillLoader
|
from agentkit.skills.loader import SkillLoader
|
||||||
|
|
||||||
loader = SkillLoader(
|
loader = SkillLoader(
|
||||||
skill_registry=skill_registry,
|
skill_registry=skill_registry,
|
||||||
tool_registry=tool_registry,
|
tool_registry=tool_registry,
|
||||||
|
|
@ -357,7 +400,7 @@ async def install_skill(request: InstallSkillRequest, req: Request):
|
||||||
os.remove(file_path)
|
os.remove(file_path)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
raise HTTPException(status_code=500, detail=f"Skill downloaded but registration failed")
|
raise HTTPException(status_code=500, detail="Skill downloaded but registration failed")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "installed",
|
"status": "installed",
|
||||||
|
|
@ -387,7 +430,9 @@ async def uninstall_skill(name: str, req: Request):
|
||||||
yaml_path = os.path.join(skills_dir, f"{validated_name}.yaml")
|
yaml_path = os.path.join(skills_dir, f"{validated_name}.yaml")
|
||||||
|
|
||||||
# Verify resolved path stays within skills_dir
|
# Verify resolved path stays within skills_dir
|
||||||
if os.path.exists(yaml_path) and os.path.realpath(yaml_path).startswith(os.path.realpath(skills_dir)):
|
if os.path.exists(yaml_path) and os.path.realpath(yaml_path).startswith(
|
||||||
|
os.path.realpath(skills_dir)
|
||||||
|
):
|
||||||
os.remove(yaml_path)
|
os.remove(yaml_path)
|
||||||
|
|
||||||
return {"status": "uninstalled", "name": validated_name}
|
return {"status": "uninstalled", "name": validated_name}
|
||||||
|
|
@ -419,8 +464,7 @@ async def create_pipeline(request: CreatePipelineRequest, req: Request):
|
||||||
return {
|
return {
|
||||||
"name": pipeline.name,
|
"name": pipeline.name,
|
||||||
"steps": [
|
"steps": [
|
||||||
{"skill_name": s["skill_name"], "step_index": i}
|
{"skill_name": s["skill_name"], "step_index": i} for i, s in enumerate(request.steps)
|
||||||
for i, s in enumerate(request.steps)
|
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,522 @@
|
||||||
|
"""Integration tests for the skill/KB admin routes (U6).
|
||||||
|
|
||||||
|
Uses FastAPI TestClient with a test app that mounts the
|
||||||
|
``admin_router`` from ``routes.admin`` plus the public ``skills``
|
||||||
|
router from ``routes.skills`` (to verify disabled-skill filtering).
|
||||||
|
The ``_require_admin`` dependency is overridden via
|
||||||
|
``app.dependency_overrides`` so the tests don't need real JWTs.
|
||||||
|
|
||||||
|
The :class:`SkillRegistry` is a real instance with a test skill
|
||||||
|
loaded from a temp YAML file, so import/reload/update endpoints
|
||||||
|
exercise the real SkillLoader pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from agentkit.server.admin.kb_service import set_kb_service
|
||||||
|
from agentkit.server.admin.skill_service import set_skill_service
|
||||||
|
from agentkit.server.auth.models import init_auth_db
|
||||||
|
from agentkit.server.routes import admin as admin_routes_module
|
||||||
|
from agentkit.server.routes import skills as skills_routes_module
|
||||||
|
from agentkit.skills.registry import SkillRegistry
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Test data
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
_VALID_SKILL_YAML = """\
|
||||||
|
name: admin_test_skill
|
||||||
|
agent_type: simple_generation
|
||||||
|
version: "1.0.0"
|
||||||
|
description: "A test skill for admin route testing"
|
||||||
|
task_mode: llm_generate
|
||||||
|
execution_mode: direct
|
||||||
|
max_steps: 1
|
||||||
|
prompt:
|
||||||
|
identity: "Test"
|
||||||
|
instructions: "Handle test"
|
||||||
|
tools: []
|
||||||
|
"""
|
||||||
|
|
||||||
|
_VALID_SKILL_YAML_2 = """\
|
||||||
|
name: another_test_skill
|
||||||
|
agent_type: simple_generation
|
||||||
|
version: "1.0.0"
|
||||||
|
description: "Another test skill"
|
||||||
|
task_mode: llm_generate
|
||||||
|
execution_mode: direct
|
||||||
|
max_steps: 1
|
||||||
|
prompt:
|
||||||
|
identity: "Test2"
|
||||||
|
instructions: "Handle test 2"
|
||||||
|
tools: []
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def tmp_auth_db(tmp_path: Path) -> Path:
|
||||||
|
db_path = tmp_path / "admin_skill_kb.db"
|
||||||
|
await init_auth_db(db_path)
|
||||||
|
return db_path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def skills_dir(tmp_path: Path) -> str:
|
||||||
|
"""A temp skills directory for YAML files."""
|
||||||
|
d = tmp_path / "skills"
|
||||||
|
d.mkdir()
|
||||||
|
return str(d)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def skill_registry() -> SkillRegistry:
|
||||||
|
return SkillRegistry()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _reset_singletons():
|
||||||
|
"""Reset SkillService and KbService singletons before/after each test."""
|
||||||
|
set_skill_service(None)
|
||||||
|
set_kb_service(None)
|
||||||
|
yield
|
||||||
|
set_skill_service(None)
|
||||||
|
set_kb_service(None)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_admin_user() -> dict[str, Any]:
|
||||||
|
return {"user_id": "admin-1", "username": "admin", "role": "admin"}
|
||||||
|
|
||||||
|
|
||||||
|
def _raise_forbidden() -> dict[str, Any]:
|
||||||
|
"""Dependency override that simulates a non-admin (403) response."""
|
||||||
|
raise HTTPException(status_code=403, detail="Admin permission required")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def admin_app(
|
||||||
|
tmp_auth_db: Path,
|
||||||
|
skills_dir: str,
|
||||||
|
skill_registry: SkillRegistry,
|
||||||
|
) -> FastAPI:
|
||||||
|
"""A minimal FastAPI app with admin + skills routers mounted.
|
||||||
|
|
||||||
|
The ``_require_admin`` dependency is overridden to return a fake
|
||||||
|
admin user. The :class:`SkillRegistry` is set on ``app.state`` so
|
||||||
|
the admin skill endpoints can access it.
|
||||||
|
"""
|
||||||
|
app = FastAPI()
|
||||||
|
app.state.auth_db_path = str(tmp_auth_db)
|
||||||
|
app.state.skill_registry = skill_registry
|
||||||
|
|
||||||
|
# Build a minimal server_config with skill_paths pointing at the
|
||||||
|
# temp skills_dir, so the admin endpoints write YAML there.
|
||||||
|
class _FakeServerConfig:
|
||||||
|
skill_paths = [skills_dir]
|
||||||
|
|
||||||
|
app.state.server_config = _FakeServerConfig()
|
||||||
|
|
||||||
|
app.include_router(admin_routes_module.admin_router, prefix="/api/v1")
|
||||||
|
app.include_router(skills_routes_module.router, prefix="/api/v1")
|
||||||
|
|
||||||
|
# Default: allow admin access.
|
||||||
|
app.dependency_overrides[admin_routes_module._require_admin] = lambda: _make_admin_user()
|
||||||
|
# Override the department context to admin (bypass filtering) so
|
||||||
|
# GET /skills returns all skills (the disabled filter still applies).
|
||||||
|
from agentkit.server.admin.context import DepartmentContext
|
||||||
|
|
||||||
|
app.dependency_overrides[skills_routes_module.get_department_context] = lambda: (
|
||||||
|
DepartmentContext(user_id="admin-1", department_ids=[], is_admin=True)
|
||||||
|
)
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def admin_client(admin_app: FastAPI) -> TestClient:
|
||||||
|
return TestClient(admin_app)
|
||||||
|
|
||||||
|
|
||||||
|
def _write_skill_yaml(skills_dir: str, name: str, content: str) -> str:
|
||||||
|
"""Write a skill YAML file to the skills dir and return the path."""
|
||||||
|
path = os.path.join(skills_dir, f"{name}.yaml")
|
||||||
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(content)
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Skill enable/disable
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSkillEnableDisable:
|
||||||
|
def test_disable_skill_returns_200(self, admin_client: TestClient):
|
||||||
|
resp = admin_client.post("/api/v1/admin/skills/admin_test_skill/disable")
|
||||||
|
assert resp.status_code == 200, resp.text
|
||||||
|
body = resp.json()
|
||||||
|
assert body["skill_name"] == "admin_test_skill"
|
||||||
|
assert body["is_disabled"] is True
|
||||||
|
|
||||||
|
def test_disable_then_get_skills_excludes_it(
|
||||||
|
self,
|
||||||
|
admin_client: TestClient,
|
||||||
|
skills_dir: str,
|
||||||
|
skill_registry: SkillRegistry,
|
||||||
|
):
|
||||||
|
# Write a skill YAML and register it.
|
||||||
|
_write_skill_yaml(skills_dir, "admin_test_skill", _VALID_SKILL_YAML)
|
||||||
|
from agentkit.skills.loader import SkillLoader
|
||||||
|
|
||||||
|
SkillLoader(skill_registry=skill_registry).load_from_file(
|
||||||
|
os.path.join(skills_dir, "admin_test_skill.yaml")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the skill is initially listed.
|
||||||
|
resp = admin_client.get("/api/v1/skills")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
names = [s["name"] for s in resp.json()]
|
||||||
|
assert "admin_test_skill" in names
|
||||||
|
|
||||||
|
# Disable the skill.
|
||||||
|
resp = admin_client.post("/api/v1/admin/skills/admin_test_skill/disable")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
# GET /skills should now exclude it.
|
||||||
|
resp = admin_client.get("/api/v1/skills")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
names = [s["name"] for s in resp.json()]
|
||||||
|
assert "admin_test_skill" not in names
|
||||||
|
|
||||||
|
def test_enable_skill_returns_200(
|
||||||
|
self,
|
||||||
|
admin_client: TestClient,
|
||||||
|
skills_dir: str,
|
||||||
|
skill_registry: SkillRegistry,
|
||||||
|
):
|
||||||
|
# Write and register the skill, then disable it.
|
||||||
|
_write_skill_yaml(skills_dir, "admin_test_skill", _VALID_SKILL_YAML)
|
||||||
|
from agentkit.skills.loader import SkillLoader
|
||||||
|
|
||||||
|
SkillLoader(skill_registry=skill_registry).load_from_file(
|
||||||
|
os.path.join(skills_dir, "admin_test_skill.yaml")
|
||||||
|
)
|
||||||
|
admin_client.post("/api/v1/admin/skills/admin_test_skill/disable")
|
||||||
|
|
||||||
|
# Enable it.
|
||||||
|
resp = admin_client.post("/api/v1/admin/skills/admin_test_skill/enable")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["enabled"] is True
|
||||||
|
|
||||||
|
# GET /skills should now include it.
|
||||||
|
resp = admin_client.get("/api/v1/skills")
|
||||||
|
names = [s["name"] for s in resp.json()]
|
||||||
|
assert "admin_test_skill" in names
|
||||||
|
|
||||||
|
def test_enable_skill_not_disabled_returns_200(self, admin_client: TestClient):
|
||||||
|
"""Enabling a skill that wasn't disabled returns enabled=False."""
|
||||||
|
resp = admin_client.post("/api/v1/admin/skills/never_disabled/enable")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["enabled"] is False
|
||||||
|
|
||||||
|
def test_disable_skill_invalid_name_returns_400(self, admin_client: TestClient):
|
||||||
|
resp = admin_client.post("/api/v1/admin/skills/Has Spaces/disable")
|
||||||
|
assert resp.status_code == 400
|
||||||
|
|
||||||
|
def test_non_admin_cannot_disable_skill(self, admin_app: FastAPI):
|
||||||
|
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
|
||||||
|
client = TestClient(admin_app)
|
||||||
|
resp = client.post("/api/v1/admin/skills/some_skill/disable")
|
||||||
|
assert resp.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Skill import
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSkillImport:
|
||||||
|
def test_import_valid_yaml_returns_200(
|
||||||
|
self,
|
||||||
|
admin_client: TestClient,
|
||||||
|
skills_dir: str,
|
||||||
|
skill_registry: SkillRegistry,
|
||||||
|
):
|
||||||
|
resp = admin_client.post(
|
||||||
|
"/api/v1/admin/skills/import",
|
||||||
|
json={"yaml_content": _VALID_SKILL_YAML},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200, resp.text
|
||||||
|
body = resp.json()
|
||||||
|
assert body["name"] == "admin_test_skill"
|
||||||
|
assert os.path.isfile(body["path"])
|
||||||
|
# Skill should be registered.
|
||||||
|
assert skill_registry.has_skill("admin_test_skill")
|
||||||
|
|
||||||
|
def test_import_invalid_yaml_returns_400(self, admin_client: TestClient):
|
||||||
|
resp = admin_client.post(
|
||||||
|
"/api/v1/admin/skills/import",
|
||||||
|
json={"yaml_content": "not: valid: yaml: ["},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 400
|
||||||
|
|
||||||
|
def test_import_missing_name_returns_400(self, admin_client: TestClient):
|
||||||
|
resp = admin_client.post(
|
||||||
|
"/api/v1/admin/skills/import",
|
||||||
|
json={"yaml_content": "agent_type: test\n"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 400
|
||||||
|
|
||||||
|
def test_import_invalid_skill_name_returns_400(self, admin_client: TestClient):
|
||||||
|
bad_yaml = 'name: "Has Spaces"\nagent_type: test\n'
|
||||||
|
resp = admin_client.post(
|
||||||
|
"/api/v1/admin/skills/import",
|
||||||
|
json={"yaml_content": bad_yaml},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 400
|
||||||
|
|
||||||
|
def test_non_admin_cannot_import(self, admin_app: FastAPI):
|
||||||
|
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
|
||||||
|
client = TestClient(admin_app)
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/admin/skills/import",
|
||||||
|
json={"yaml_content": _VALID_SKILL_YAML},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Skill reload
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSkillReload:
|
||||||
|
def test_reload_existing_skill_returns_200(
|
||||||
|
self,
|
||||||
|
admin_client: TestClient,
|
||||||
|
skills_dir: str,
|
||||||
|
skill_registry: SkillRegistry,
|
||||||
|
):
|
||||||
|
# Write and register the skill first.
|
||||||
|
_write_skill_yaml(skills_dir, "admin_test_skill", _VALID_SKILL_YAML)
|
||||||
|
from agentkit.skills.loader import SkillLoader
|
||||||
|
|
||||||
|
SkillLoader(skill_registry=skill_registry).load_from_file(
|
||||||
|
os.path.join(skills_dir, "admin_test_skill.yaml")
|
||||||
|
)
|
||||||
|
|
||||||
|
resp = admin_client.post("/api/v1/admin/skills/admin_test_skill/reload")
|
||||||
|
assert resp.status_code == 200, resp.text
|
||||||
|
body = resp.json()
|
||||||
|
assert body["name"] == "admin_test_skill"
|
||||||
|
assert body["status"] == "reloaded"
|
||||||
|
|
||||||
|
def test_reload_nonexistent_skill_returns_404(self, admin_client: TestClient):
|
||||||
|
resp = admin_client.post("/api/v1/admin/skills/nonexistent/reload")
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Skill update (PATCH)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSkillUpdate:
|
||||||
|
def test_update_skill_returns_200(
|
||||||
|
self,
|
||||||
|
admin_client: TestClient,
|
||||||
|
skills_dir: str,
|
||||||
|
skill_registry: SkillRegistry,
|
||||||
|
):
|
||||||
|
# Write and register the skill first.
|
||||||
|
_write_skill_yaml(skills_dir, "admin_test_skill", _VALID_SKILL_YAML)
|
||||||
|
from agentkit.skills.loader import SkillLoader
|
||||||
|
|
||||||
|
SkillLoader(skill_registry=skill_registry).load_from_file(
|
||||||
|
os.path.join(skills_dir, "admin_test_skill.yaml")
|
||||||
|
)
|
||||||
|
|
||||||
|
resp = admin_client.patch(
|
||||||
|
"/api/v1/admin/skills/admin_test_skill",
|
||||||
|
json={"config": {"description": "Updated via admin API"}},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200, resp.text
|
||||||
|
body = resp.json()
|
||||||
|
assert body["status"] == "updated"
|
||||||
|
|
||||||
|
# Verify the YAML file was updated.
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
with open(body["path"], encoding="utf-8") as f:
|
||||||
|
data = yaml.safe_load(f)
|
||||||
|
assert data["description"] == "Updated via admin API"
|
||||||
|
|
||||||
|
def test_update_nonexistent_skill_returns_404(self, admin_client: TestClient):
|
||||||
|
resp = admin_client.patch(
|
||||||
|
"/api/v1/admin/skills/nonexistent",
|
||||||
|
json={"config": {"description": "x"}},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# KB document endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestKbDocumentRoutes:
|
||||||
|
def test_list_documents_returns_200(self, admin_client: TestClient):
|
||||||
|
resp = admin_client.get("/api/v1/admin/kb/documents")
|
||||||
|
assert resp.status_code == 200, resp.text
|
||||||
|
body = resp.json()
|
||||||
|
assert "documents" in body
|
||||||
|
assert isinstance(body["documents"], list)
|
||||||
|
|
||||||
|
def test_upload_document_returns_201_with_department_id(self, admin_client: TestClient):
|
||||||
|
dept_id = str(uuid.uuid4())
|
||||||
|
resp = admin_client.post(
|
||||||
|
"/api/v1/admin/kb/documents",
|
||||||
|
json={
|
||||||
|
"filename": "test.txt",
|
||||||
|
"content": "hello world",
|
||||||
|
"department_id": dept_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 201, resp.text
|
||||||
|
body = resp.json()
|
||||||
|
assert body["filename"] == "test.txt"
|
||||||
|
assert body["department_id"] == dept_id
|
||||||
|
assert "document_id" in body
|
||||||
|
|
||||||
|
def test_upload_document_without_department_id(self, admin_client: TestClient):
|
||||||
|
resp = admin_client.post(
|
||||||
|
"/api/v1/admin/kb/documents",
|
||||||
|
json={"filename": "global.txt", "content": "global content"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
body = resp.json()
|
||||||
|
assert body["department_id"] is None
|
||||||
|
|
||||||
|
def test_upload_then_delete_document(self, admin_client: TestClient):
|
||||||
|
# Upload
|
||||||
|
resp = admin_client.post(
|
||||||
|
"/api/v1/admin/kb/documents",
|
||||||
|
json={"filename": "delete_me.txt", "content": "x"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
doc_id = resp.json()["document_id"]
|
||||||
|
|
||||||
|
# Delete
|
||||||
|
resp = admin_client.delete(f"/api/v1/admin/kb/documents/{doc_id}")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json() == {"deleted": True}
|
||||||
|
|
||||||
|
# Second delete should 404.
|
||||||
|
resp = admin_client.delete(f"/api/v1/admin/kb/documents/{doc_id}")
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
def test_delete_nonexistent_document_returns_404(self, admin_client: TestClient):
|
||||||
|
resp = admin_client.delete("/api/v1/admin/kb/documents/nonexistent-id")
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
def test_list_documents_filters_by_department_id(self, admin_client: TestClient):
|
||||||
|
dept_a = str(uuid.uuid4())
|
||||||
|
dept_b = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Upload 3 docs: one for dept_a, one for dept_b, one global.
|
||||||
|
admin_client.post(
|
||||||
|
"/api/v1/admin/kb/documents",
|
||||||
|
json={"filename": "a.txt", "content": "a", "department_id": dept_a},
|
||||||
|
)
|
||||||
|
admin_client.post(
|
||||||
|
"/api/v1/admin/kb/documents",
|
||||||
|
json={"filename": "b.txt", "content": "b", "department_id": dept_b},
|
||||||
|
)
|
||||||
|
admin_client.post(
|
||||||
|
"/api/v1/admin/kb/documents",
|
||||||
|
json={"filename": "global.txt", "content": "g"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter by dept_a: should see dept_a's doc + global doc.
|
||||||
|
resp = admin_client.get("/api/v1/admin/kb/documents", params={"department_id": dept_a})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
docs = resp.json()["documents"]
|
||||||
|
dept_ids = {d["department_id"] for d in docs}
|
||||||
|
assert dept_a in dept_ids
|
||||||
|
assert None in dept_ids # global doc
|
||||||
|
assert dept_b not in dept_ids
|
||||||
|
|
||||||
|
def test_non_admin_cannot_list_documents(self, admin_app: FastAPI):
|
||||||
|
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
|
||||||
|
client = TestClient(admin_app)
|
||||||
|
resp = client.get("/api/v1/admin/kb/documents")
|
||||||
|
assert resp.status_code == 403
|
||||||
|
|
||||||
|
def test_non_admin_cannot_upload_document(self, admin_app: FastAPI):
|
||||||
|
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
|
||||||
|
client = TestClient(admin_app)
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/admin/kb/documents",
|
||||||
|
json={"filename": "x.txt", "content": "x"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# KB source sync/rebuild
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestKbSourceRoutes:
|
||||||
|
def test_sync_nonexistent_source_returns_404(self, admin_client: TestClient):
|
||||||
|
resp = admin_client.post("/api/v1/admin/kb/sources/nonexistent/sync")
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
def test_rebuild_nonexistent_source_returns_404(self, admin_client: TestClient):
|
||||||
|
resp = admin_client.post("/api/v1/admin/kb/sources/nonexistent/rebuild")
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
def test_sync_existing_source_returns_200(self, admin_client: TestClient):
|
||||||
|
# Add a source via the KbService (bypassing the route layer).
|
||||||
|
from agentkit.server.admin.kb_service import get_kb_service
|
||||||
|
|
||||||
|
svc = get_kb_service()
|
||||||
|
store = svc._resolve_store()
|
||||||
|
source = store.add_source("Test Source", "local", {})
|
||||||
|
|
||||||
|
resp = admin_client.post(f"/api/v1/admin/kb/sources/{source.id}/sync")
|
||||||
|
assert resp.status_code == 200, resp.text
|
||||||
|
body = resp.json()
|
||||||
|
assert body["status"] == "syncing"
|
||||||
|
|
||||||
|
def test_rebuild_existing_source_returns_200(self, admin_client: TestClient):
|
||||||
|
from agentkit.server.admin.kb_service import get_kb_service
|
||||||
|
|
||||||
|
svc = get_kb_service()
|
||||||
|
store = svc._resolve_store()
|
||||||
|
source = store.add_source("Test Source", "local", {})
|
||||||
|
|
||||||
|
resp = admin_client.post(f"/api/v1/admin/kb/sources/{source.id}/rebuild")
|
||||||
|
assert resp.status_code == 200, resp.text
|
||||||
|
body = resp.json()
|
||||||
|
assert body["status"] == "rebuilding"
|
||||||
|
|
||||||
|
def test_non_admin_cannot_sync_source(self, admin_app: FastAPI):
|
||||||
|
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
|
||||||
|
client = TestClient(admin_app)
|
||||||
|
resp = client.post("/api/v1/admin/kb/sources/any/sync")
|
||||||
|
assert resp.status_code == 403
|
||||||
|
|
@ -4,7 +4,7 @@ Covers:
|
||||||
- ``init_auth_db`` creates the new V3 tables (departments, user_departments,
|
- ``init_auth_db`` creates the new V3 tables (departments, user_departments,
|
||||||
department_skill_bindings, department_kb_bindings, department_quotas)
|
department_skill_bindings, department_kb_bindings, department_quotas)
|
||||||
- ``init_auth_db`` is idempotent (calling twice does not error)
|
- ``init_auth_db`` is idempotent (calling twice does not error)
|
||||||
- ``_SCHEMA_VERSION`` is recorded as 3 in ``auth_meta``
|
- ``_SCHEMA_VERSION`` is recorded as 4 in ``auth_meta`` (V4 adds skill_states)
|
||||||
- ``departments`` insert + query round-trip
|
- ``departments`` insert + query round-trip
|
||||||
- ``user_departments`` many-to-many relationship (one user → many departments,
|
- ``user_departments`` many-to-many relationship (one user → many departments,
|
||||||
one department → many users)
|
one department → many users)
|
||||||
|
|
@ -109,9 +109,7 @@ async def _list_index_names(db: aiosqlite.Connection, table: str) -> set[str]:
|
||||||
|
|
||||||
async def _list_table_names(db: aiosqlite.Connection) -> set[str]:
|
async def _list_table_names(db: aiosqlite.Connection) -> set[str]:
|
||||||
"""Return the set of table names in the SQLite file."""
|
"""Return the set of table names in the SQLite file."""
|
||||||
cursor = await db.execute(
|
cursor = await db.execute("SELECT name FROM sqlite_master WHERE type='table'")
|
||||||
"SELECT name FROM sqlite_master WHERE type='table'"
|
|
||||||
)
|
|
||||||
rows = await cursor.fetchall()
|
rows = await cursor.fetchall()
|
||||||
return {row[0] for row in rows}
|
return {row[0] for row in rows}
|
||||||
|
|
||||||
|
|
@ -122,9 +120,9 @@ async def _list_table_names(db: aiosqlite.Connection) -> set[str]:
|
||||||
|
|
||||||
|
|
||||||
class TestSchemaVersion:
|
class TestSchemaVersion:
|
||||||
def test_schema_version_is_v3(self):
|
def test_schema_version_is_v4(self):
|
||||||
"""V3 adds the department-scoped admin tables."""
|
"""V4 adds the skill_states table (U6 — Admin Console)."""
|
||||||
assert _SCHEMA_VERSION == 3
|
assert _SCHEMA_VERSION == 4
|
||||||
|
|
||||||
def test_sqlalchemy_model_table_names(self):
|
def test_sqlalchemy_model_table_names(self):
|
||||||
assert DepartmentModel.__tablename__ == "departments"
|
assert DepartmentModel.__tablename__ == "departments"
|
||||||
|
|
@ -165,13 +163,13 @@ class TestInitAuthDbTables:
|
||||||
tables = await _list_table_names(db)
|
tables = await _list_table_names(db)
|
||||||
assert "department_quotas" in tables
|
assert "department_quotas" in tables
|
||||||
|
|
||||||
async def test_records_schema_version_3_in_auth_meta(self, fresh_db: Path):
|
async def test_records_schema_version_4_in_auth_meta(self, fresh_db: Path):
|
||||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
db.row_factory = aiosqlite.Row
|
db.row_factory = aiosqlite.Row
|
||||||
cursor = await db.execute("SELECT value FROM auth_meta WHERE key='schema_version'")
|
cursor = await db.execute("SELECT value FROM auth_meta WHERE key='schema_version'")
|
||||||
row = await cursor.fetchone()
|
row = await cursor.fetchone()
|
||||||
assert row is not None
|
assert row is not None
|
||||||
assert row["value"] == "3"
|
assert row["value"] == "4"
|
||||||
assert row["value"] == str(_SCHEMA_VERSION)
|
assert row["value"] == str(_SCHEMA_VERSION)
|
||||||
|
|
||||||
async def test_init_auth_db_is_idempotent(self, tmp_path: Path):
|
async def test_init_auth_db_is_idempotent(self, tmp_path: Path):
|
||||||
|
|
@ -279,9 +277,7 @@ class TestDepartmentsCrud:
|
||||||
)
|
)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
db.row_factory = aiosqlite.Row
|
db.row_factory = aiosqlite.Row
|
||||||
cursor = await db.execute(
|
cursor = await db.execute("SELECT description FROM departments WHERE id=?", (dept_id,))
|
||||||
"SELECT description FROM departments WHERE id=?", (dept_id,)
|
|
||||||
)
|
|
||||||
row = await cursor.fetchone()
|
row = await cursor.fetchone()
|
||||||
assert row is not None
|
assert row is not None
|
||||||
assert row["description"] is None
|
assert row["description"] is None
|
||||||
|
|
@ -310,8 +306,7 @@ class TestUserDepartmentsManyToMany:
|
||||||
await db.commit()
|
await db.commit()
|
||||||
db.row_factory = aiosqlite.Row
|
db.row_factory = aiosqlite.Row
|
||||||
cursor = await db.execute(
|
cursor = await db.execute(
|
||||||
"SELECT department_id FROM user_departments WHERE user_id=? "
|
"SELECT department_id FROM user_departments WHERE user_id=? ORDER BY department_id",
|
||||||
"ORDER BY department_id",
|
|
||||||
(user_id,),
|
(user_id,),
|
||||||
)
|
)
|
||||||
rows = await cursor.fetchall()
|
rows = await cursor.fetchall()
|
||||||
|
|
@ -336,8 +331,7 @@ class TestUserDepartmentsManyToMany:
|
||||||
await db.commit()
|
await db.commit()
|
||||||
db.row_factory = aiosqlite.Row
|
db.row_factory = aiosqlite.Row
|
||||||
cursor = await db.execute(
|
cursor = await db.execute(
|
||||||
"SELECT user_id FROM user_departments WHERE department_id=? "
|
"SELECT user_id FROM user_departments WHERE department_id=? ORDER BY user_id",
|
||||||
"ORDER BY user_id",
|
|
||||||
(dept_id,),
|
(dept_id,),
|
||||||
)
|
)
|
||||||
rows = await cursor.fetchall()
|
rows = await cursor.fetchall()
|
||||||
|
|
@ -394,9 +388,7 @@ class TestDepartmentSkillBindingsUnique:
|
||||||
)
|
)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
async def test_same_skill_name_in_different_departments_is_allowed(
|
async def test_same_skill_name_in_different_departments_is_allowed(self, fresh_db: Path):
|
||||||
self, fresh_db: Path
|
|
||||||
):
|
|
||||||
dept_a = str(uuid.uuid4())
|
dept_a = str(uuid.uuid4())
|
||||||
dept_b = str(uuid.uuid4())
|
dept_b = str(uuid.uuid4())
|
||||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,386 @@
|
||||||
|
"""Unit tests for SkillService (U6 — skill enable/disable + import/reload).
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- disable_skill → is_skill_disabled returns True
|
||||||
|
- enable_skill → is_skill_disabled returns False
|
||||||
|
- list_disabled_skills → returns correct list
|
||||||
|
- import_skill with valid YAML → file written, skill loaded
|
||||||
|
- import_skill with invalid YAML → ValueError
|
||||||
|
- import_skill with path traversal attempt → ValueError
|
||||||
|
- reload_skill → unregister + reload
|
||||||
|
- update_skill_config → YAML updated + reloaded
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.server.admin.skill_service import SkillService, _validate_skill_name
|
||||||
|
from agentkit.server.auth.models import init_auth_db
|
||||||
|
from agentkit.skills.registry import SkillRegistry
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def fresh_db(tmp_path: Path) -> Path:
|
||||||
|
"""A brand-new auth DB on a fresh path (no data)."""
|
||||||
|
db_path = tmp_path / "auth.db"
|
||||||
|
await init_auth_db(db_path)
|
||||||
|
return db_path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def service() -> SkillService:
|
||||||
|
return SkillService()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def skills_dir(tmp_path: Path) -> str:
|
||||||
|
"""A temp skills directory for YAML files."""
|
||||||
|
d = tmp_path / "skills"
|
||||||
|
d.mkdir()
|
||||||
|
return str(d)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def skill_registry() -> SkillRegistry:
|
||||||
|
return SkillRegistry()
|
||||||
|
|
||||||
|
|
||||||
|
_VALID_SKILL_YAML = """\
|
||||||
|
name: test_skill
|
||||||
|
agent_type: simple_generation
|
||||||
|
version: "1.0.0"
|
||||||
|
description: "A test skill for unit testing"
|
||||||
|
task_mode: llm_generate
|
||||||
|
execution_mode: direct
|
||||||
|
max_steps: 1
|
||||||
|
prompt:
|
||||||
|
identity: "Test"
|
||||||
|
instructions: "Handle test"
|
||||||
|
tools: []
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Enable / disable
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestDisableEnable:
|
||||||
|
async def test_disable_skill_marks_disabled(self, service: SkillService, fresh_db: Path):
|
||||||
|
result = await service.disable_skill(fresh_db, "test_skill", disabled_by="admin-1")
|
||||||
|
assert result["skill_name"] == "test_skill"
|
||||||
|
assert result["is_disabled"] is True
|
||||||
|
assert result["disabled_by"] == "admin-1"
|
||||||
|
assert "disabled_at" in result
|
||||||
|
|
||||||
|
assert await service.is_skill_disabled(fresh_db, "test_skill") is True
|
||||||
|
|
||||||
|
async def test_disable_skill_normalizes_name(self, service: SkillService, fresh_db: Path):
|
||||||
|
"""Skill names are normalized to lowercase before storage."""
|
||||||
|
await service.disable_skill(fresh_db, "TestSkill")
|
||||||
|
assert await service.is_skill_disabled(fresh_db, "testskill") is True
|
||||||
|
assert await service.is_skill_disabled(fresh_db, "TestSkill") is True
|
||||||
|
|
||||||
|
async def test_disable_skill_is_idempotent(self, service: SkillService, fresh_db: Path):
|
||||||
|
"""Disabling an already-disabled skill updates the row, not duplicates."""
|
||||||
|
await service.disable_skill(fresh_db, "skill_a", disabled_by="admin-1")
|
||||||
|
await service.disable_skill(fresh_db, "skill_a", disabled_by="admin-2")
|
||||||
|
disabled = await service.list_disabled_skills(fresh_db)
|
||||||
|
assert disabled == ["skill_a"]
|
||||||
|
|
||||||
|
async def test_enable_skill_removes_disabled_mark(self, service: SkillService, fresh_db: Path):
|
||||||
|
await service.disable_skill(fresh_db, "skill_b")
|
||||||
|
assert await service.is_skill_disabled(fresh_db, "skill_b") is True
|
||||||
|
|
||||||
|
enabled = await service.enable_skill(fresh_db, "skill_b")
|
||||||
|
assert enabled is True
|
||||||
|
assert await service.is_skill_disabled(fresh_db, "skill_b") is False
|
||||||
|
|
||||||
|
async def test_enable_skill_returns_false_if_not_disabled(
|
||||||
|
self, service: SkillService, fresh_db: Path
|
||||||
|
):
|
||||||
|
"""Enabling a skill that wasn't disabled returns False (no-op)."""
|
||||||
|
enabled = await service.enable_skill(fresh_db, "never_disabled")
|
||||||
|
assert enabled is False
|
||||||
|
|
||||||
|
async def test_is_skill_disabled_returns_false_for_unknown(
|
||||||
|
self, service: SkillService, fresh_db: Path
|
||||||
|
):
|
||||||
|
assert await service.is_skill_disabled(fresh_db, "unknown_skill") is False
|
||||||
|
|
||||||
|
async def test_list_disabled_skills_returns_sorted_list(
|
||||||
|
self, service: SkillService, fresh_db: Path
|
||||||
|
):
|
||||||
|
await service.disable_skill(fresh_db, "charlie")
|
||||||
|
await service.disable_skill(fresh_db, "alpha")
|
||||||
|
await service.disable_skill(fresh_db, "bravo")
|
||||||
|
result = await service.list_disabled_skills(fresh_db)
|
||||||
|
assert result == ["alpha", "bravo", "charlie"]
|
||||||
|
|
||||||
|
async def test_list_disabled_skills_empty_when_none_disabled(
|
||||||
|
self, service: SkillService, fresh_db: Path
|
||||||
|
):
|
||||||
|
result = await service.list_disabled_skills(fresh_db)
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
async def test_disable_skill_invalid_name_raises(self, service: SkillService, fresh_db: Path):
|
||||||
|
with pytest.raises(ValueError, match="Invalid skill name"):
|
||||||
|
await service.disable_skill(fresh_db, "Invalid Name With Spaces")
|
||||||
|
|
||||||
|
async def test_disable_skill_uppercase_name_normalizes(
|
||||||
|
self, service: SkillService, fresh_db: Path
|
||||||
|
):
|
||||||
|
"""Uppercase names are normalized to lowercase (regex requires lowercase)."""
|
||||||
|
# 'TEST_SKILL' normalizes to 'test_skill' which matches the regex.
|
||||||
|
await service.disable_skill(fresh_db, "TEST_SKILL")
|
||||||
|
assert await service.is_skill_disabled(fresh_db, "test_skill") is True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# import_skill
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestImportSkill:
|
||||||
|
async def test_import_skill_writes_file_and_loads(
|
||||||
|
self,
|
||||||
|
service: SkillService,
|
||||||
|
skills_dir: str,
|
||||||
|
skill_registry: SkillRegistry,
|
||||||
|
):
|
||||||
|
result = await service.import_skill(
|
||||||
|
_VALID_SKILL_YAML,
|
||||||
|
skills_dir,
|
||||||
|
skill_registry=skill_registry,
|
||||||
|
)
|
||||||
|
assert result["name"] == "test_skill"
|
||||||
|
assert Path(result["path"]).is_file()
|
||||||
|
# Skill should be registered in the registry.
|
||||||
|
assert skill_registry.has_skill("test_skill")
|
||||||
|
|
||||||
|
async def test_import_skill_without_registry_writes_file_only(
|
||||||
|
self,
|
||||||
|
service: SkillService,
|
||||||
|
skills_dir: str,
|
||||||
|
):
|
||||||
|
result = await service.import_skill(_VALID_SKILL_YAML, skills_dir)
|
||||||
|
assert result["name"] == "test_skill"
|
||||||
|
assert Path(result["path"]).is_file()
|
||||||
|
|
||||||
|
async def test_import_skill_invalid_yaml_raises(
|
||||||
|
self,
|
||||||
|
service: SkillService,
|
||||||
|
skills_dir: str,
|
||||||
|
):
|
||||||
|
with pytest.raises(ValueError, match="Invalid YAML"):
|
||||||
|
await service.import_skill("not: valid: yaml: [", skills_dir)
|
||||||
|
|
||||||
|
async def test_import_skill_non_mapping_yaml_raises(
|
||||||
|
self,
|
||||||
|
service: SkillService,
|
||||||
|
skills_dir: str,
|
||||||
|
):
|
||||||
|
with pytest.raises(ValueError, match="mapping"):
|
||||||
|
await service.import_skill("- just\n- a\n- list\n", skills_dir)
|
||||||
|
|
||||||
|
async def test_import_skill_missing_name_field_raises(
|
||||||
|
self,
|
||||||
|
service: SkillService,
|
||||||
|
skills_dir: str,
|
||||||
|
):
|
||||||
|
yaml_without_name = "agent_type: simple_generation\n"
|
||||||
|
with pytest.raises(ValueError, match="name"):
|
||||||
|
await service.import_skill(yaml_without_name, skills_dir)
|
||||||
|
|
||||||
|
async def test_import_skill_invalid_name_raises(
|
||||||
|
self,
|
||||||
|
service: SkillService,
|
||||||
|
skills_dir: str,
|
||||||
|
):
|
||||||
|
"""YAML with a name that fails the regex raises ValueError."""
|
||||||
|
bad_yaml = 'name: "Bad Name With Spaces"\nagent_type: test\n'
|
||||||
|
with pytest.raises(ValueError, match="Invalid skill name"):
|
||||||
|
await service.import_skill(bad_yaml, skills_dir)
|
||||||
|
|
||||||
|
async def test_import_skill_path_traversal_blocked(
|
||||||
|
self,
|
||||||
|
service: SkillService,
|
||||||
|
skills_dir: str,
|
||||||
|
):
|
||||||
|
"""A YAML with a name containing path separators is rejected by the regex."""
|
||||||
|
# The regex `^[a-z0-9][a-z0-9_-]{0,63}$` rejects '/' and '..',
|
||||||
|
# so path traversal via the name field is impossible. We verify
|
||||||
|
# this by attempting to import a YAML with a traversal name.
|
||||||
|
traversal_yaml = 'name: "../etc/passwd"\nagent_type: test\n'
|
||||||
|
with pytest.raises(ValueError, match="Invalid skill name"):
|
||||||
|
await service.import_skill(traversal_yaml, skills_dir)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# reload_skill
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestReloadSkill:
|
||||||
|
async def test_reload_skill_unregisters_and_reloads(
|
||||||
|
self,
|
||||||
|
service: SkillService,
|
||||||
|
skills_dir: str,
|
||||||
|
skill_registry: SkillRegistry,
|
||||||
|
):
|
||||||
|
# First import the skill.
|
||||||
|
await service.import_skill(
|
||||||
|
_VALID_SKILL_YAML,
|
||||||
|
skills_dir,
|
||||||
|
skill_registry=skill_registry,
|
||||||
|
)
|
||||||
|
assert skill_registry.has_skill("test_skill")
|
||||||
|
|
||||||
|
# Now reload it.
|
||||||
|
result = await service.reload_skill("test_skill", skill_registry, skills_dir)
|
||||||
|
assert result["name"] == "test_skill"
|
||||||
|
assert result["status"] == "reloaded"
|
||||||
|
assert skill_registry.has_skill("test_skill")
|
||||||
|
|
||||||
|
async def test_reload_skill_missing_yaml_raises(
|
||||||
|
self,
|
||||||
|
service: SkillService,
|
||||||
|
skills_dir: str,
|
||||||
|
skill_registry: SkillRegistry,
|
||||||
|
):
|
||||||
|
with pytest.raises(ValueError, match="not found"):
|
||||||
|
await service.reload_skill("nonexistent", skill_registry, skills_dir)
|
||||||
|
|
||||||
|
async def test_reload_skill_invalid_name_raises(
|
||||||
|
self,
|
||||||
|
service: SkillService,
|
||||||
|
skills_dir: str,
|
||||||
|
skill_registry: SkillRegistry,
|
||||||
|
):
|
||||||
|
with pytest.raises(ValueError, match="Invalid skill name"):
|
||||||
|
await service.reload_skill("Bad Name", skill_registry, skills_dir)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# update_skill_config
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestUpdateSkillConfig:
|
||||||
|
async def test_update_skill_config_updates_yaml_and_reloads(
|
||||||
|
self,
|
||||||
|
service: SkillService,
|
||||||
|
skills_dir: str,
|
||||||
|
skill_registry: SkillRegistry,
|
||||||
|
):
|
||||||
|
# First import the skill.
|
||||||
|
await service.import_skill(
|
||||||
|
_VALID_SKILL_YAML,
|
||||||
|
skills_dir,
|
||||||
|
skill_registry=skill_registry,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Patch the description.
|
||||||
|
result = await service.update_skill_config(
|
||||||
|
"test_skill",
|
||||||
|
{"description": "Updated description"},
|
||||||
|
skills_dir,
|
||||||
|
skill_registry,
|
||||||
|
)
|
||||||
|
assert result["name"] == "test_skill"
|
||||||
|
assert result["status"] == "updated"
|
||||||
|
|
||||||
|
# Verify the YAML file was updated.
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
with open(result["path"], encoding="utf-8") as f:
|
||||||
|
data = yaml.safe_load(f)
|
||||||
|
assert data["description"] == "Updated description"
|
||||||
|
# Original fields should be preserved.
|
||||||
|
assert data["name"] == "test_skill"
|
||||||
|
assert data["agent_type"] == "simple_generation"
|
||||||
|
|
||||||
|
# Skill should still be registered.
|
||||||
|
assert skill_registry.has_skill("test_skill")
|
||||||
|
|
||||||
|
async def test_update_skill_config_missing_yaml_raises(
|
||||||
|
self,
|
||||||
|
service: SkillService,
|
||||||
|
skills_dir: str,
|
||||||
|
skill_registry: SkillRegistry,
|
||||||
|
):
|
||||||
|
with pytest.raises(ValueError, match="not found"):
|
||||||
|
await service.update_skill_config(
|
||||||
|
"nonexistent",
|
||||||
|
{"description": "x"},
|
||||||
|
skills_dir,
|
||||||
|
skill_registry,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_update_skill_config_preserves_name(
|
||||||
|
self,
|
||||||
|
service: SkillService,
|
||||||
|
skills_dir: str,
|
||||||
|
skill_registry: SkillRegistry,
|
||||||
|
):
|
||||||
|
"""Patching the 'name' field is ignored — name is preserved."""
|
||||||
|
await service.import_skill(
|
||||||
|
_VALID_SKILL_YAML,
|
||||||
|
skills_dir,
|
||||||
|
skill_registry=skill_registry,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await service.update_skill_config(
|
||||||
|
"test_skill",
|
||||||
|
{"name": "should_be_ignored", "description": "new"},
|
||||||
|
skills_dir,
|
||||||
|
skill_registry,
|
||||||
|
)
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
with open(result["path"], encoding="utf-8") as f:
|
||||||
|
data = yaml.safe_load(f)
|
||||||
|
# Name should be preserved (not renamed).
|
||||||
|
assert data["name"] == "test_skill"
|
||||||
|
assert data["description"] == "new"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _validate_skill_name helper
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateSkillName:
|
||||||
|
def test_valid_name_returns_normalized(self):
|
||||||
|
assert _validate_skill_name("test_skill") == "test_skill"
|
||||||
|
assert _validate_skill_name("TestSkill") == "testskill"
|
||||||
|
assert _validate_skill_name(" spaced ") == "spaced"
|
||||||
|
assert _validate_skill_name("a-b_c-1") == "a-b_c-1"
|
||||||
|
|
||||||
|
def test_invalid_name_raises(self):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_validate_skill_name("Has Spaces")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_validate_skill_name("")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_validate_skill_name("-leading-dash")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_validate_skill_name("../traversal")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_validate_skill_name("has.dot")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_validate_skill_name("has/slash")
|
||||||
|
|
||||||
|
def test_non_string_raises(self):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_validate_skill_name(None) # type: ignore[arg-type]
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_validate_skill_name(123) # type: ignore[arg-type]
|
||||||
|
|
@ -118,9 +118,9 @@ async def _list_index_names(db: aiosqlite.Connection, table: str) -> set[str]:
|
||||||
|
|
||||||
|
|
||||||
class TestSchemaVersion:
|
class TestSchemaVersion:
|
||||||
def test_schema_version_is_v3(self):
|
def test_schema_version_is_v4(self):
|
||||||
"""The current schema version is 3 (V3 adds department-scoped admin tables)."""
|
"""The current schema version is 4 (V4 adds skill_states table)."""
|
||||||
assert _SCHEMA_VERSION == 3
|
assert _SCHEMA_VERSION == 4
|
||||||
|
|
||||||
def test_sqlalchemy_model_table_name(self):
|
def test_sqlalchemy_model_table_name(self):
|
||||||
assert AuthSessionModel.__tablename__ == "auth_sessions"
|
assert AuthSessionModel.__tablename__ == "auth_sessions"
|
||||||
|
|
|
||||||
|
|
@ -2,13 +2,12 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from unittest.mock import AsyncMock
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from agentkit.llm.gateway import LLMGateway
|
from agentkit.llm.gateway import LLMGateway
|
||||||
from agentkit.server.app import create_app
|
from agentkit.server.app import create_app
|
||||||
|
from agentkit.server.config import ServerConfig
|
||||||
from agentkit.skills.base import Skill, SkillConfig
|
from agentkit.skills.base import Skill, SkillConfig
|
||||||
from agentkit.skills.registry import SkillRegistry
|
from agentkit.skills.registry import SkillRegistry
|
||||||
from agentkit.tools.registry import ToolRegistry
|
from agentkit.tools.registry import ToolRegistry
|
||||||
|
|
@ -202,7 +201,31 @@ class TestListCapabilities:
|
||||||
|
|
||||||
|
|
||||||
class TestReloadSkill:
|
class TestReloadSkill:
|
||||||
def test_reload_skill(self, client, skill_registry):
|
def test_reload_skill(self, client, skill_registry, tmp_path):
|
||||||
|
# Write a valid YAML file for "reload_skill" to a temp skills dir.
|
||||||
|
skills_dir = tmp_path / "skills"
|
||||||
|
skills_dir.mkdir()
|
||||||
|
yaml_content = (
|
||||||
|
"name: reload_skill\n"
|
||||||
|
'agent_type: "test_type"\n'
|
||||||
|
'version: "1.0.0"\n'
|
||||||
|
'description: "Skill for reload testing"\n'
|
||||||
|
"task_mode: llm_generate\n"
|
||||||
|
"execution_mode: direct\n"
|
||||||
|
"max_steps: 1\n"
|
||||||
|
"intent:\n"
|
||||||
|
' keywords: ["reload"]\n'
|
||||||
|
' description: "reload test"\n'
|
||||||
|
"prompt:\n"
|
||||||
|
' identity: "reload tester"\n'
|
||||||
|
' instructions: "handle reload"\n'
|
||||||
|
"tools: []\n"
|
||||||
|
)
|
||||||
|
(skills_dir / "reload_skill.yaml").write_text(yaml_content, encoding="utf-8")
|
||||||
|
|
||||||
|
# Point the app at the temp skills dir and register the skill so the
|
||||||
|
# 404 guard in the route passes.
|
||||||
|
client.app.state.server_config = ServerConfig(skill_paths=[str(skills_dir)])
|
||||||
_register_skill(skill_registry, "reload_skill")
|
_register_skill(skill_registry, "reload_skill")
|
||||||
|
|
||||||
response = client.post("/api/v1/skill-management/skills/reload_skill/reload")
|
response = client.post("/api/v1/skill-management/skills/reload_skill/reload")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue