382 lines
14 KiB
Python
382 lines
14 KiB
Python
"""Skill registration routes"""
|
|
|
|
import logging
|
|
import os
|
|
import re
|
|
import urllib.parse
|
|
|
|
import httpx
|
|
from fastapi import APIRouter, HTTPException, Request
|
|
from pydantic import BaseModel
|
|
from typing import Any
|
|
|
|
from agentkit.skills.base import Skill, SkillConfig
|
|
from agentkit.skills.pipeline import SkillPipeline
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router carrying every skill/pipeline endpoint in this file, tagged "skills".
router = APIRouter(tags=["skills"])
|
|
|
|
# Skill names: a lowercase letter or digit followed by up to 63 lowercase
# letters, digits, hyphens or underscores (1-64 characters in total).
_SKILL_NAME_RE = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")

# Hosts that skill source URLs are expected to come from.  This list is
# advisory: _validate_source_url only logs a warning for other hosts, and its
# SSRF protection is the private/internal-address check.  Stored as a
# frozenset so the allowlist cannot be mutated at runtime.
_ALLOWED_DOWNLOAD_DOMAINS = frozenset(
    {
        "raw.githubusercontent.com",
        "github.com",
        "gist.githubusercontent.com",
    }
)
|
|
|
|
|
|
def _validate_skill_name(name: str) -> str:
    """Return *name* trimmed and lower-cased, or reject it with an HTTP 400.

    The normalized value must match ``_SKILL_NAME_RE``: one lowercase letter or
    digit followed by at most 63 lowercase letters, digits, hyphens or
    underscores.  The error message echoes the caller's original input.
    """
    candidate = name.strip().lower()
    if _SKILL_NAME_RE.match(candidate) is not None:
        return candidate
    raise HTTPException(
        status_code=400,
        detail=f"Invalid skill name '{name}': must contain only lowercase letters, digits, hyphens, and underscores (1-64 chars)",
    )
|
|
|
|
|
|
def _get_skills_dir(req: Request) -> str:
    """Return the directory that installed skill YAML files live in.

    Uses the first entry of ``server_config.skill_paths`` when it names an
    existing directory; otherwise falls back to ``<cwd>/configs/skills``
    (returned without checking that it exists).
    """
    server_config = getattr(req.app.state, "server_config", None)
    configured_paths = server_config.skill_paths if server_config else None

    if configured_paths:
        from pathlib import Path

        install_target = Path(configured_paths[0])
        if install_target.is_dir():
            return str(install_target)

    # Nothing usable configured: default to configs/skills under the working dir.
    return os.path.join(os.getcwd(), "configs", "skills")
|
|
|
|
|
|
def _validate_source_url(source: str) -> None:
    """Reject a download URL that is not http(s) or resolves to a non-public address.

    SSRF mitigation for user-supplied skill sources:

    * the scheme must be ``http`` or ``https``;
    * every address the hostname resolves to must be globally routable.
      Private, loopback, link-local, reserved, shared (100.64.0.0/10),
      multicast and unparseable addresses are all rejected;
    * hosts outside ``_ALLOWED_DOWNLOAD_DOMAINS`` are allowed, but a warning
      is logged.

    Limitations: DNS is resolved here with a blocking call, and the HTTP client
    performs its own lookup when it connects, so this is not a defence against
    DNS rebinding.  If resolution fails here, the URL is let through and the
    download itself is expected to fail.

    Raises:
        HTTPException: 400 for a disallowed scheme or a non-public address.
    """
    import ipaddress
    import socket
    from urllib.parse import urlparse

    parsed = urlparse(source)
    if parsed.scheme not in ("https", "http"):
        raise HTTPException(status_code=400, detail="Invalid source URL scheme: only http/https allowed")

    hostname = parsed.hostname
    if not hostname:
        return

    try:
        resolved = socket.getaddrinfo(hostname, None)
    except (socket.gaierror, UnicodeError):
        # Unresolvable or invalid (IDNA) name: let the HTTP client report it.
        resolved = []

    for _family, _socktype, _proto, _canonname, sockaddr in resolved:
        try:
            ip = ipaddress.ip_address(sockaddr[0])
        except ValueError:
            ip = None
        # Fail closed on anything unparseable or not globally routable.
        # ``is_global`` also excludes 100.64.0.0/10 shared address space, which
        # ``is_private`` does not cover.  Multicast is tested separately
        # because ``is_global`` is True for it.
        if ip is None or not ip.is_global or ip.is_multicast:
            raise HTTPException(
                status_code=400,
                detail="Source URL points to a private/internal address — not allowed",
            )

    # The domain allowlist is advisory only: log and continue.
    if hostname not in _ALLOWED_DOWNLOAD_DOMAINS:
        logger.warning(
            "Source URL domain '%s' is not in the allowlist: %s",
            hostname,
            _ALLOWED_DOWNLOAD_DOMAINS,
        )
|
|
|
|
|
|
def _validate_yaml_content(content: str) -> dict:
    """Parse skill YAML and check it has the minimum shape needed to install it.

    ``yaml.safe_load`` is used (never ``yaml.load``) because the content may
    come from an untrusted download.

    Args:
        content: Raw YAML text.

    Returns:
        The parsed top-level mapping.

    Raises:
        HTTPException: 400 if the text is not valid YAML, is not a mapping, or
            lacks a non-empty string ``name`` field.
    """
    import yaml

    try:
        data = yaml.safe_load(content)
    except yaml.YAMLError as e:
        raise HTTPException(status_code=400, detail=f"Invalid YAML content: {e}") from e

    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Skill YAML must be a mapping/dict")

    if "name" not in data:
        raise HTTPException(status_code=400, detail="Skill YAML must contain a 'name' field")

    # "name:" (null), numbers or lists parse fine but cannot name a skill.
    # Reject them here rather than writing them to disk and relying on the
    # loader to refuse them.
    name = data["name"]
    if not isinstance(name, str) or not name.strip():
        raise HTTPException(status_code=400, detail="Skill YAML 'name' field must be a non-empty string")

    return data
|
|
|
|
|
|
class RegisterSkillRequest(BaseModel):
    """Body of ``POST /skills``: a raw config dict handed to ``SkillConfig.from_dict``."""

    config: dict[str, Any]
|
|
|
|
|
|
class CreatePipelineRequest(BaseModel):
    """Body of ``POST /skills/pipelines``."""

    # Name given to the new SkillPipeline.
    name: str
    # Ordered step definitions; create_pipeline requires a "skill_name" key in each.
    steps: list[dict[str, Any]]
|
|
|
|
|
|
class ExecutePipelineRequest(BaseModel):
    """Body of ``POST /skills/pipelines/{name}/execute``."""

    # Passed through as pipeline.execute(input_data=...).
    input_data: dict[str, Any]
|
|
|
|
|
|
class InstallSkillRequest(BaseModel):
    """Body of ``POST /skills/install``."""

    # Skill name; normalized by _validate_skill_name and used as "<name>.yaml".
    name: str
    # Optional YAML source.  A value starting with "http" is downloaded, and
    # "file://<path>" is read from a path inside the skills directory.  Anything
    # else, including None or a "github:user/repo/path" string (which is NOT
    # parsed), falls back to a GitHub code search by name.
    source: str | None = None
|
|
|
|
|
|
@router.post("/skills", status_code=201)
async def register_skill(request: RegisterSkillRequest, req: Request):
    """Build a Skill from the posted config and add it to the skill registry.

    Responds 422 when ``SkillConfig.from_dict`` rejects the config.  On success,
    returns a summary of the new skill: name, agent_type, version and
    description.
    """
    registry = req.app.state.skill_registry

    try:
        skill_config = SkillConfig.from_dict(request.config)
    except Exception as e:
        raise HTTPException(status_code=422, detail=f"Invalid skill config: {e}")

    new_skill = Skill(config=skill_config)
    registry.register(new_skill)

    cfg = new_skill.config
    summary = {
        "name": new_skill.name,
        "agent_type": cfg.agent_type,
        "version": cfg.version,
        "description": cfg.description,
    }
    return summary
|
|
|
|
|
|
@router.get("/skills")
async def list_skills(req: Request):
    """Describe every registered skill.

    Each entry holds the skill's name, agent type and version, plus optional
    config fields.  Falsy/unset optional values are reported as "" or [].
    """
    registry = req.app.state.skill_registry

    described = []
    for skill in registry.list_skills():
        cfg = skill.config
        intent = cfg.intent
        described.append(
            {
                "name": skill.name,
                "agent_type": cfg.agent_type,
                "version": cfg.version,
                "description": cfg.description or "",
                "task_mode": cfg.task_mode or "",
                "intent_keywords": intent.keywords if intent else [],
                "intent_description": intent.description if intent else "",
                "tools": cfg.tools or [],
                "bound_tools": [tool.name for tool in (skill.tools or [])],
                "prompt_identity": (cfg.prompt or {}).get("identity", ""),
            }
        )
    return described
|
|
|
|
|
|
def _is_path_within(path: str, directory: str) -> bool:
    """Return True if *path* resolves (symlinks followed) to *directory* or below it."""
    real_path = os.path.realpath(path)
    real_dir = os.path.realpath(directory)
    try:
        # commonpath compares whole path components, unlike str.startswith,
        # which would also accept a sibling such as "<dir>_evil/...".
        return os.path.commonpath([real_path, real_dir]) == real_dir
    except ValueError:
        # Raised e.g. for paths on different drives on Windows.
        return False


async def _download_text(url: str) -> str:
    """GET *url* and return the body text, SSRF-checking every request hop.

    Raises:
        HTTPException: if the initial URL or any redirect target fails
            ``_validate_source_url``.
        httpx.HTTPError: on transport failures or a non-2xx final status.
    """
    import asyncio

    async def _check_hop(request: httpx.Request) -> None:
        # httpx runs request hooks before the initial request and before every
        # redirect it follows.  A redirect to an internal address is therefore
        # rejected here instead of bypassing validation.  The check does
        # blocking DNS resolution, so it runs in a worker thread.
        await asyncio.to_thread(_validate_source_url, str(request.url))

    async with httpx.AsyncClient(
        timeout=30,
        follow_redirects=True,
        max_redirects=3,
        event_hooks={"request": [_check_hop]},
    ) as client:
        resp = await client.get(url)
        resp.raise_for_status()
        return resp.text


def _read_local_skill_file(local_path: str, skills_dir: str) -> str:
    """Return the text of *local_path*, which must lie inside *skills_dir*."""
    # Check containment before existence so this endpoint cannot be used to
    # probe whether arbitrary files exist elsewhere on the server.
    if not _is_path_within(local_path, skills_dir):
        raise HTTPException(status_code=400, detail="Local file path must be within the skills directory")
    if not os.path.exists(local_path):
        raise HTTPException(status_code=404, detail=f"Local file not found: {local_path}")
    try:
        with open(local_path, encoding="utf-8") as f:
            return f.read()
    except (OSError, UnicodeDecodeError) as e:
        raise HTTPException(status_code=400, detail=f"Failed to read local file: {e}") from e


async def _search_github_items(skill_name: str) -> list[Any]:
    """Return GitHub code-search hits (up to 5) for YAML files matching *skill_name*.

    A specific query is tried first.  If it yields nothing, a broader query is
    tried, and any failure of that fallback counts as "no results".

    Raises:
        HTTPException: 502 if the primary request fails outright, or if GitHub
            answered it with an error status (e.g. authentication required or
            rate limited) and the fallback found nothing either.
    """
    headers = {"Accept": "application/vnd.github.v3+json", "User-Agent": "agentkit"}

    primary_q = urllib.parse.quote(f"{skill_name} skill config filename:yaml")
    try:
        async with httpx.AsyncClient(timeout=15) as client:
            gh_resp = await client.get(
                f"https://api.github.com/search/code?q={primary_q}&per_page=5",
                headers=headers,
            )
        gh_data = gh_resp.json()
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"GitHub search failed: {e}") from e

    primary_error: str | None = None
    if gh_resp.is_error:
        message = gh_data.get("message", "") if isinstance(gh_data, dict) else ""
        primary_error = f"HTTP {gh_resp.status_code} {message}".rstrip()

    items = gh_data.get("items", []) if isinstance(gh_data, dict) else []
    if items:
        return items

    fallback_q = urllib.parse.quote(f"{skill_name} skill")
    try:
        async with httpx.AsyncClient(timeout=15) as client:
            gh_resp2 = await client.get(
                f"https://api.github.com/search/code?q={fallback_q}+extension:yaml&per_page=5",
                headers=headers,
            )
        fallback_data = gh_resp2.json()
        items = fallback_data.get("items", []) if isinstance(fallback_data, dict) else []
    except Exception:
        items = []  # best-effort fallback: failures count as "no results"

    if not items and primary_error:
        # Surface GitHub's own error instead of a misleading "no skill found".
        raise HTTPException(status_code=502, detail=f"GitHub search failed: {primary_error}")
    return items


def _github_html_to_raw_url(html_url: str) -> str:
    """Map a github.com blob URL to the matching raw.githubusercontent.com URL."""
    prefix = "https://github.com/"
    if not html_url.startswith(prefix):
        raise HTTPException(status_code=400, detail="Search result URL is not from github.com")
    # html_url has the form https://github.com/<owner>/<repo>/blob/<ref>/<path>.
    # Only the host prefix and the "blob" segment at that position are
    # rewritten.  Blanket str.replace calls would also corrupt owner, repo or
    # path text that happens to contain "github.com" or "/blob/".
    parts = html_url[len(prefix):].split("/")
    if len(parts) > 2 and parts[2] == "blob":
        del parts[2]
    return "https://raw.githubusercontent.com/" + "/".join(parts)


def _save_and_register_skill(
    skill_name: str,
    yaml_content: str,
    skills_dir: str,
    skill_registry: Any,
    tool_registry: Any,
) -> str:
    """Write *yaml_content* to ``<skills_dir>/<skill_name>.yaml`` and load it.

    If loading fails, the file is restored to its previous bytes, or removed if
    it did not exist before.  A failed reinstall therefore cannot destroy an
    already-installed skill file.

    Returns:
        The path the YAML was written to.

    Raises:
        HTTPException: 400 if the target path escapes *skills_dir*, and 500 if
            the skill loader rejects the file.
    """
    os.makedirs(skills_dir, exist_ok=True)
    file_path = os.path.join(skills_dir, f"{skill_name}.yaml")

    # Path traversal / symlink-escape protection.
    if not _is_path_within(file_path, skills_dir):
        raise HTTPException(status_code=400, detail="Invalid path: escapes skills directory")

    previous: bytes | None = None
    if os.path.isfile(file_path):
        with open(file_path, "rb") as f:
            previous = f.read()

    with open(file_path, "w", encoding="utf-8") as f:
        f.write(yaml_content)

    try:
        from agentkit.skills.loader import SkillLoader

        loader = SkillLoader(
            skill_registry=skill_registry,
            tool_registry=tool_registry,
        )
        loader.load_from_file(file_path)
    except Exception as e:
        logger.warning("Failed to register installed skill: %s", e)
        try:
            if previous is None:
                os.remove(file_path)
            else:
                with open(file_path, "wb") as f:
                    f.write(previous)
        except OSError:
            logger.warning("Could not roll back %s after failed registration", file_path, exc_info=True)
        raise HTTPException(status_code=500, detail="Skill downloaded but registration failed") from e

    return file_path


@router.post("/skills/install")
async def install_skill(request: InstallSkillRequest, req: Request):
    """Install a skill YAML and register it.

    Where the YAML comes from depends on ``request.source``:

    * starts with ``http``: downloaded from that URL.  The URL and every
      redirect hop must pass ``_validate_source_url``;
    * ``file://<path>``: read from ``<path>``, which must be inside the skills
      directory;
    * anything else, including no source: a GitHub code search for the name,
      after which the first hit is downloaded.

    The content is validated, saved as ``<skills_dir>/<name>.yaml`` and loaded
    with ``SkillLoader``.

    Returns:
        ``{"status": "installed", "name": <name>, "path": <file path>}``.

    Raises:
        HTTPException: 400 for invalid input, rejected URLs or download/read
            failures; 404 when nothing is found; 502 when GitHub search fails;
            500 when the loader rejects the skill.
    """
    skill_name = _validate_skill_name(request.name)
    source = request.source

    skill_registry = req.app.state.skill_registry
    tool_registry = getattr(req.app.state, "tool_registry", None)

    if source and source.startswith("http"):
        try:
            yaml_content = await _download_text(source)
        except HTTPException:
            raise  # URL (or a redirect target) rejected by _validate_source_url
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Failed to download from source: {e}") from e
    elif source and source.startswith("file://"):
        yaml_content = _read_local_skill_file(source[len("file://"):], _get_skills_dir(req))
    else:
        # Other source forms (e.g. "github:user/repo/path") are not parsed;
        # they fall through to a search by name.
        items = await _search_github_items(skill_name)
        if not items:
            raise HTTPException(status_code=404, detail=f"No skill found matching '{skill_name}'")
        html_url = items[0].get("html_url", "")
        if not html_url:
            raise HTTPException(status_code=404, detail="Could not construct download URL")
        raw_url = _github_html_to_raw_url(html_url)
        try:
            yaml_content = await _download_text(raw_url)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Failed to download skill: {e}") from e

    # Reject malformed content before anything touches disk.
    _validate_yaml_content(yaml_content)

    file_path = _save_and_register_skill(
        skill_name,
        yaml_content,
        _get_skills_dir(req),
        skill_registry,
        tool_registry,
    )

    return {
        "status": "installed",
        "name": skill_name,
        "path": file_path,
    }
|
|
|
|
|
|
@router.delete("/skills/{name}")
async def uninstall_skill(name: str, req: Request):
    """Unregister a skill and delete its installed YAML file, if one exists.

    The file removed is ``<skills_dir>/<name>.yaml``, where ``skills_dir``
    comes from ``_get_skills_dir``.  Nothing outside that directory is
    deleted.

    Returns:
        ``{"status": "uninstalled", "name": <normalized name>}``.

    Raises:
        HTTPException: 400 for an invalid name, and 404 if the registry lookup
            fails.
    """
    # The name regex allows no separators or dots, which also blocks path traversal.
    validated_name = _validate_skill_name(name)

    skill_registry = req.app.state.skill_registry

    try:
        skill_registry.get(validated_name)
    except Exception:
        # The registry's "not found" exception type isn't visible here, so any
        # lookup failure is reported as 404.
        raise HTTPException(status_code=404, detail=f"Skill '{name}' not found") from None

    skill_registry.unregister(validated_name)

    skills_dir = _get_skills_dir(req)
    yaml_path = os.path.join(skills_dir, f"{validated_name}.yaml")

    real_dir = os.path.realpath(skills_dir)
    try:
        # Compare whole path components.  str.startswith would also accept a
        # sibling directory that merely shares the prefix.
        inside = os.path.commonpath([os.path.realpath(yaml_path), real_dir]) == real_dir
    except ValueError:  # e.g. paths on different drives on Windows
        inside = False

    if inside:
        try:
            os.remove(yaml_path)
        except FileNotFoundError:
            pass  # no YAML file at that path, or it was removed concurrently

    return {"status": "uninstalled", "name": validated_name}
|
|
|
|
|
|
# ---- Pipeline endpoints ----
|
|
|
|
|
|
@router.post("/skills/pipelines", status_code=201)
async def create_pipeline(request: CreatePipelineRequest, req: Request):
    """Build a SkillPipeline from the request and register it.

    Every step dict must contain ``skill_name``; the first step that does not
    is reported with a 422.  The response gives the pipeline name and, for each
    step, its ``skill_name`` and position.
    """
    registry = req.app.state.skill_registry
    steps = request.steps

    bad_index = next(
        (idx for idx, step in enumerate(steps) if "skill_name" not in step),
        None,
    )
    if bad_index is not None:
        raise HTTPException(
            status_code=422,
            detail=f"Step {bad_index} missing required field 'skill_name'",
        )

    pipeline = SkillPipeline(name=request.name, steps=steps, skill_registry=registry)
    registry.register_pipeline(pipeline)

    return {
        "name": pipeline.name,
        "steps": [
            dict(skill_name=step["skill_name"], step_index=idx)
            for idx, step in enumerate(steps)
        ],
    }
|
|
|
|
|
|
@router.get("/skills/pipelines")
async def list_pipelines(req: Request):
    """Return the skill registry's listing of registered pipelines, unchanged."""
    return req.app.state.skill_registry.list_pipelines()
|
|
|
|
|
|
@router.post("/skills/pipelines/{name}/execute")
async def execute_pipeline(name: str, request: ExecutePipelineRequest, req: Request):
    """Run the pipeline registered as *name* on the posted input data.

    Responds 404 for an unknown pipeline.  If execution raises, the traceback
    is logged (not returned) and the client gets a generic 500.
    """
    pipeline = req.app.state.skill_registry.get_pipeline(name)
    if pipeline is None:
        raise HTTPException(status_code=404, detail=f"Pipeline '{name}' not found")

    try:
        return await pipeline.execute(input_data=request.input_data)
    except Exception as e:
        logger.exception(f"Pipeline execution failed for '{name}': {e}")
        raise HTTPException(status_code=500, detail="Pipeline execution failed")
|