387 lines
14 KiB
Python
387 lines
14 KiB
Python
"""Unit tests for SkillService (U6 — skill enable/disable + import/reload).
|
|
|
|
Covers:
|
|
- disable_skill → is_skill_disabled returns True
|
|
- enable_skill → is_skill_disabled returns False
|
|
- list_disabled_skills → returns correct list
|
|
- import_skill with valid YAML → file written, skill loaded
|
|
- import_skill with invalid YAML → ValueError
|
|
- import_skill with path traversal attempt → ValueError
|
|
- reload_skill → unregister + reload
|
|
- update_skill_config → YAML updated + reloaded
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from agentkit.server.admin.skill_service import SkillService, _validate_skill_name
|
|
from agentkit.server.auth.models import init_auth_db
|
|
from agentkit.skills.registry import SkillRegistry
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture
async def fresh_db(tmp_path: Path) -> Path:
    """Create an auth DB at ``<tmp_path>/auth.db`` via ``init_auth_db`` and return its path.

    No rows are inserted here, so each test starts from a clean database.
    """
    path = tmp_path.joinpath("auth.db")
    await init_auth_db(path)
    return path
|
|
|
|
|
|
@pytest.fixture
def service() -> SkillService:
    """Provide a freshly constructed ``SkillService`` for each test."""
    svc = SkillService()
    return svc
|
|
|
|
|
|
@pytest.fixture
def skills_dir(tmp_path: Path) -> str:
    """Create an empty ``skills`` directory under ``tmp_path`` and return it as a string.

    The service API takes the directory as ``str``, so the path is converted here.
    """
    directory = tmp_path.joinpath("skills")
    directory.mkdir()
    return str(directory)
|
|
|
|
|
|
@pytest.fixture
def skill_registry() -> SkillRegistry:
    """Provide a freshly constructed ``SkillRegistry`` for each test."""
    registry = SkillRegistry()
    return registry
|
|
|
|
|
|
_VALID_SKILL_YAML = """\
|
|
name: test_skill
|
|
agent_type: simple_generation
|
|
version: "1.0.0"
|
|
description: "A test skill for unit testing"
|
|
task_mode: llm_generate
|
|
execution_mode: direct
|
|
max_steps: 1
|
|
prompt:
|
|
identity: "Test"
|
|
instructions: "Handle test"
|
|
tools: []
|
|
"""
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Enable / disable
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestDisableEnable:
    """disable_skill / enable_skill / is_skill_disabled / list_disabled_skills."""

    async def test_disable_skill_marks_disabled(self, service: SkillService, fresh_db: Path):
        record = await service.disable_skill(fresh_db, "test_skill", disabled_by="admin-1")

        # The returned record describes the disable operation...
        assert record["skill_name"] == "test_skill"
        assert record["is_disabled"] is True
        assert record["disabled_by"] == "admin-1"
        assert "disabled_at" in record

        # ...and a subsequent lookup sees the skill as disabled.
        assert await service.is_skill_disabled(fresh_db, "test_skill") is True

    async def test_disable_skill_normalizes_name(self, service: SkillService, fresh_db: Path):
        """Skill names are normalized to lowercase before storage."""
        await service.disable_skill(fresh_db, "TestSkill")

        for lookup in ("testskill", "TestSkill"):
            assert await service.is_skill_disabled(fresh_db, lookup) is True

    async def test_disable_skill_is_idempotent(self, service: SkillService, fresh_db: Path):
        """Disabling an already-disabled skill updates the row, not duplicates."""
        for admin in ("admin-1", "admin-2"):
            await service.disable_skill(fresh_db, "skill_a", disabled_by=admin)

        assert await service.list_disabled_skills(fresh_db) == ["skill_a"]

    async def test_enable_skill_removes_disabled_mark(self, service: SkillService, fresh_db: Path):
        await service.disable_skill(fresh_db, "skill_b")
        assert await service.is_skill_disabled(fresh_db, "skill_b") is True

        was_disabled = await service.enable_skill(fresh_db, "skill_b")

        assert was_disabled is True
        assert await service.is_skill_disabled(fresh_db, "skill_b") is False

    async def test_enable_skill_returns_false_if_not_disabled(
        self, service: SkillService, fresh_db: Path
    ):
        """Enabling a skill that wasn't disabled returns False (no-op)."""
        assert await service.enable_skill(fresh_db, "never_disabled") is False

    async def test_is_skill_disabled_returns_false_for_unknown(
        self, service: SkillService, fresh_db: Path
    ):
        disabled = await service.is_skill_disabled(fresh_db, "unknown_skill")
        assert disabled is False

    async def test_list_disabled_skills_returns_sorted_list(
        self, service: SkillService, fresh_db: Path
    ):
        # Insert deliberately out of order so the sort is actually exercised.
        for name in ("charlie", "alpha", "bravo"):
            await service.disable_skill(fresh_db, name)

        assert await service.list_disabled_skills(fresh_db) == ["alpha", "bravo", "charlie"]

    async def test_list_disabled_skills_empty_when_none_disabled(
        self, service: SkillService, fresh_db: Path
    ):
        assert await service.list_disabled_skills(fresh_db) == []

    async def test_disable_skill_invalid_name_raises(self, service: SkillService, fresh_db: Path):
        with pytest.raises(ValueError, match="Invalid skill name"):
            await service.disable_skill(fresh_db, "Invalid Name With Spaces")

    async def test_disable_skill_uppercase_name_normalizes(
        self, service: SkillService, fresh_db: Path
    ):
        """Uppercase names are normalized to lowercase (regex requires lowercase)."""
        # 'TEST_SKILL' lowercases to 'test_skill', which satisfies the regex.
        await service.disable_skill(fresh_db, "TEST_SKILL")

        assert await service.is_skill_disabled(fresh_db, "test_skill") is True
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# import_skill
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestImportSkill:
    """import_skill: validate skill YAML, write it under ``skills_dir``, optionally register it."""

    async def test_import_skill_writes_file_and_loads(
        self,
        service: SkillService,
        skills_dir: str,
        skill_registry: SkillRegistry,
    ):
        result = await service.import_skill(
            _VALID_SKILL_YAML,
            skills_dir,
            skill_registry=skill_registry,
        )
        assert result["name"] == "test_skill"
        assert Path(result["path"]).is_file()
        # Skill should be registered in the registry.
        assert skill_registry.has_skill("test_skill")

    async def test_import_skill_without_registry_writes_file_only(
        self,
        service: SkillService,
        skills_dir: str,
    ):
        result = await service.import_skill(_VALID_SKILL_YAML, skills_dir)
        assert result["name"] == "test_skill"
        assert Path(result["path"]).is_file()

    async def test_import_skill_invalid_yaml_raises(
        self,
        service: SkillService,
        skills_dir: str,
    ):
        with pytest.raises(ValueError, match="Invalid YAML"):
            await service.import_skill("not: valid: yaml: [", skills_dir)

    async def test_import_skill_non_mapping_yaml_raises(
        self,
        service: SkillService,
        skills_dir: str,
    ):
        with pytest.raises(ValueError, match="mapping"):
            await service.import_skill("- just\n- a\n- list\n", skills_dir)

    async def test_import_skill_missing_name_field_raises(
        self,
        service: SkillService,
        skills_dir: str,
    ):
        yaml_without_name = "agent_type: simple_generation\n"
        with pytest.raises(ValueError, match="name"):
            await service.import_skill(yaml_without_name, skills_dir)

    async def test_import_skill_invalid_name_raises(
        self,
        service: SkillService,
        skills_dir: str,
    ):
        """YAML with a name that fails the regex raises ValueError."""
        bad_yaml = 'name: "Bad Name With Spaces"\nagent_type: test\n'
        with pytest.raises(ValueError, match="Invalid skill name"):
            await service.import_skill(bad_yaml, skills_dir)

    async def test_import_skill_path_traversal_blocked(
        self,
        service: SkillService,
        skills_dir: str,
    ):
        """A YAML with a name containing path separators is rejected by the regex.

        Besides raising ValueError, the rejected import must not write anything.
        That covers both the traversal target outside ``skills_dir`` and
        ``skills_dir`` itself.
        """
        # The regex `^[a-z0-9][a-z0-9_-]{0,63}$` rejects '/' and '..',
        # so path traversal via the name field should be impossible. We verify
        # this by attempting to import a YAML with a traversal name.
        traversal_yaml = 'name: "../etc/passwd"\nagent_type: test\n'
        with pytest.raises(ValueError, match="Invalid skill name"):
            await service.import_skill(traversal_yaml, skills_dir)

        root = Path(skills_dir)
        # "<skills_dir>/../etc/passwd" would land in a sibling "etc" directory.
        assert not (root.parent / "etc").exists()
        # A rejected import must not leave any file behind inside skills_dir.
        assert list(root.iterdir()) == []
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# reload_skill
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestReloadSkill:
    """reload_skill: unregister a skill and load it again from its YAML file."""

    async def test_reload_skill_unregisters_and_reloads(
        self,
        service: SkillService,
        skills_dir: str,
        skill_registry: SkillRegistry,
    ):
        # Arrange: put the skill on disk and in the registry.
        await service.import_skill(_VALID_SKILL_YAML, skills_dir, skill_registry=skill_registry)
        assert skill_registry.has_skill("test_skill")

        # Act: reload it.
        outcome = await service.reload_skill("test_skill", skill_registry, skills_dir)

        # Assert: reported as reloaded and still registered.
        assert outcome["name"] == "test_skill"
        assert outcome["status"] == "reloaded"
        assert skill_registry.has_skill("test_skill")

    async def test_reload_skill_missing_yaml_raises(
        self,
        service: SkillService,
        skills_dir: str,
        skill_registry: SkillRegistry,
    ):
        # Nothing was imported, so there is no YAML file to reload from.
        with pytest.raises(ValueError, match="not found"):
            await service.reload_skill("nonexistent", skill_registry, skills_dir)

    async def test_reload_skill_invalid_name_raises(
        self,
        service: SkillService,
        skills_dir: str,
        skill_registry: SkillRegistry,
    ):
        with pytest.raises(ValueError, match="Invalid skill name"):
            await service.reload_skill("Bad Name", skill_registry, skills_dir)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# update_skill_config
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestUpdateSkillConfig:
    """update_skill_config: merge a patch into the skill YAML and reload it."""

    async def test_update_skill_config_updates_yaml_and_reloads(
        self,
        service: SkillService,
        skills_dir: str,
        skill_registry: SkillRegistry,
    ):
        # Arrange: import the skill so there is a YAML file to patch.
        await service.import_skill(_VALID_SKILL_YAML, skills_dir, skill_registry=skill_registry)

        # Act: patch only the description.
        patch = {"description": "Updated description"}
        result = await service.update_skill_config("test_skill", patch, skills_dir, skill_registry)

        assert result["name"] == "test_skill"
        assert result["status"] == "updated"

        # The file on disk carries the new description...
        import yaml

        on_disk = yaml.safe_load(Path(result["path"]).read_text(encoding="utf-8"))
        assert on_disk["description"] == "Updated description"
        # ...while untouched fields keep their original values.
        assert on_disk["name"] == "test_skill"
        assert on_disk["agent_type"] == "simple_generation"

        # The skill remains registered after the reload.
        assert skill_registry.has_skill("test_skill")

    async def test_update_skill_config_missing_yaml_raises(
        self,
        service: SkillService,
        skills_dir: str,
        skill_registry: SkillRegistry,
    ):
        patch = {"description": "x"}
        with pytest.raises(ValueError, match="not found"):
            await service.update_skill_config("nonexistent", patch, skills_dir, skill_registry)

    async def test_update_skill_config_preserves_name(
        self,
        service: SkillService,
        skills_dir: str,
        skill_registry: SkillRegistry,
    ):
        """Patching the 'name' field is ignored — name is preserved."""
        await service.import_skill(_VALID_SKILL_YAML, skills_dir, skill_registry=skill_registry)

        patch = {"name": "should_be_ignored", "description": "new"}
        result = await service.update_skill_config("test_skill", patch, skills_dir, skill_registry)

        import yaml

        on_disk = yaml.safe_load(Path(result["path"]).read_text(encoding="utf-8"))
        # The skill must not be renamed by a patch.
        assert on_disk["name"] == "test_skill"
        assert on_disk["description"] == "new"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _validate_skill_name helper
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestValidateSkillName:
    """_validate_skill_name: strip + lowercase, then match `^[a-z0-9][a-z0-9_-]{0,63}$`."""

    def test_valid_name_returns_normalized(self):
        assert _validate_skill_name("test_skill") == "test_skill"
        assert _validate_skill_name("TestSkill") == "testskill"
        assert _validate_skill_name(" spaced ") == "spaced"
        assert _validate_skill_name("a-b_c-1") == "a-b_c-1"

    def test_invalid_name_raises(self):
        with pytest.raises(ValueError):
            _validate_skill_name("Has Spaces")
        with pytest.raises(ValueError):
            _validate_skill_name("")
        with pytest.raises(ValueError):
            _validate_skill_name("-leading-dash")
        # The first character must be [a-z0-9]; '_' is only allowed afterwards.
        with pytest.raises(ValueError):
            _validate_skill_name("_leading-underscore")
        with pytest.raises(ValueError):
            _validate_skill_name("../traversal")
        with pytest.raises(ValueError):
            _validate_skill_name("has.dot")
        with pytest.raises(ValueError):
            _validate_skill_name("has/slash")

    def test_length_boundary(self):
        """At most 64 characters are accepted (1 leading char + up to 63 more)."""
        longest = "a" * 64
        assert _validate_skill_name(longest) == longest
        with pytest.raises(ValueError):
            _validate_skill_name("a" * 65)

    def test_non_string_raises(self):
        with pytest.raises(ValueError):
            _validate_skill_name(None)  # type: ignore[arg-type]
        with pytest.raises(ValueError):
            _validate_skill_name(123)  # type: ignore[arg-type]
|