374 lines
12 KiB
Python
374 lines
12 KiB
Python
"""Tests for Settings API routes"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import unittest.mock
|
|
|
|
import pytest
|
|
import yaml
|
|
from fastapi.testclient import TestClient
|
|
|
|
from agentkit.server.app import create_app
|
|
from agentkit.server.config import ServerConfig
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture
def config_file(tmp_path):
    """Write a sample ``agentkit.yaml`` under *tmp_path* and return its path as a string.

    The file defines two LLM providers (``openai`` and ``anthropic``) with
    plain-text API keys so key masking can be exercised, plus the ``server``,
    ``skills``, ``logging`` and ``memory.semantic`` sections read by the
    settings endpoints under test.
    """
    llm_providers = {
        "openai": {
            "type": "openai",
            "api_key": "sk-test-12345678",
            "base_url": "https://api.openai.com/v1",
            "models": {"gpt-4o": {}},
            "max_tokens": 4096,
            "timeout": 120.0,
        },
        "anthropic": {
            "type": "anthropic",
            "api_key": "sk-ant-test-abcd1234",
            "base_url": "https://api.anthropic.com",
            "models": {"claude-sonnet-4-20250514": {}},
            "max_tokens": 4096,
            "timeout": 120.0,
        },
    }
    semantic_memory = {
        "enabled": True,
        "base_url": "http://localhost:8080",
        "knowledge_base_ids": ["kb-1", "kb-2"],
        "search_mode": "standard",
        "top_k": 5,
    }
    config_data = {
        "server": {
            "host": "0.0.0.0",
            "port": 8001,
            "workers": 1,
            "rate_limit": 60,
            "cors_origins": ["*"],
        },
        "llm": {
            "providers": llm_providers,
            "model_aliases": {"gpt4": "openai/gpt-4o"},
            "fallbacks": {},
        },
        "skills": {
            "paths": ["/tmp/skills"],
            "auto_discover": True,
        },
        "logging": {
            "level": "INFO",
            "format": "text",
        },
        "memory": {"semantic": semantic_memory},
    }

    config_path = tmp_path / "agentkit.yaml"
    # yaml.dump without a stream returns the rendered document as a str.
    rendered = yaml.dump(config_data, default_flow_style=False, allow_unicode=True)
    config_path.write_text(rendered, encoding="utf-8")
    return str(config_path)
|
|
|
|
|
|
@pytest.fixture
def server_config(config_file):
    """Parse the temporary YAML file into a ``ServerConfig`` instance."""
    parsed = ServerConfig.from_yaml(config_file)
    return parsed
|
|
|
|
|
|
@pytest.fixture
def app(server_config):
    """Build the application under test, wired to the temporary server config."""
    application = create_app(server_config=server_config)
    return application
|
|
|
|
|
|
@pytest.fixture
def client(app):
    """Return a ``TestClient`` bound to the ``app`` fixture."""
    test_client = TestClient(app)
    return test_client
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# GET /settings/llm
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestGetLlmSettings:
    """GET /api/v1/settings/llm returns provider config with masked API keys."""

    # Plain-text API keys written by the ``config_file`` fixture, keyed by
    # provider name. Used to derive the expected masked form and to make sure
    # the raw secrets never leak into the response.
    _REAL_KEYS = {
        "openai": "sk-test-12345678",
        "anthropic": "sk-ant-test-abcd1234",
    }

    def test_returns_llm_config(self, client):
        response = client.get("/api/v1/settings/llm")
        assert response.status_code == 200
        data = response.json()
        assert "providers" in data
        assert "model_aliases" in data
        assert "fallbacks" in data

    def test_api_keys_are_masked(self, client):
        """Each key is returned as ``****`` + its last 4 chars; raw keys never appear."""
        response = client.get("/api/v1/settings/llm")
        assert response.status_code == 200

        # The raw secrets must not appear anywhere in the serialized payload.
        for secret in self._REAL_KEYS.values():
            assert secret not in response.text

        providers = response.json()["providers"]
        # Without this, an empty provider list would make the loop below a no-op.
        assert providers
        for provider in providers:
            real_key = self._REAL_KEYS[provider["name"]]
            # Every fixture provider has a key, so an empty/missing key is a
            # regression rather than something to skip.
            assert provider["api_key"] == "****" + real_key[-4:]

    def test_providers_have_correct_fields(self, client):
        response = client.get("/api/v1/settings/llm")
        providers = response.json()["providers"]
        assert len(providers) >= 1
        required = (
            "name",
            "type",
            "api_key",
            "base_url",
            "models",
            "max_tokens",
            "timeout",
        )
        # Check every provider, not just the first one.
        for provider in providers:
            missing = [field for field in required if field not in provider]
            assert not missing, f"provider {provider.get('name')!r} missing {missing}"

    def test_model_aliases_returned(self, client):
        """model_aliases are built from model alias fields in providers."""
        response = client.get("/api/v1/settings/llm")
        assert response.status_code == 200
        data = response.json()
        # model_aliases is populated from model "alias" fields, not the top-level
        # key; the fixture defines no such fields, so only the type is checked.
        assert isinstance(data["model_aliases"], dict)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# PUT /settings/llm
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestUpdateLlmSettings:
    """PUT /api/v1/settings/llm writes changes back to the YAML config file."""

    @staticmethod
    def _load_saved(config_file):
        """Return the parsed contents of the config file currently on disk."""
        with open(config_file, encoding="utf-8") as f:
            return yaml.safe_load(f)

    def test_update_llm_config(self, client, config_file):
        response = client.put(
            "/api/v1/settings/llm",
            json={
                "providers": [
                    {
                        "name": "openai",
                        "type": "openai",
                        "api_key": "sk-new-key-9999",
                        "base_url": "https://api.openai.com/v1",
                        "models": {"gpt-4o-mini": {}},
                        "max_tokens": 2048,
                        "timeout": 60.0,
                    }
                ],
            },
        )
        assert response.status_code == 200
        assert "providers" in response.json()

        # A new, unmasked key must actually be persisted, not just echoed back.
        saved = self._load_saved(config_file)
        assert saved["llm"]["providers"]["openai"]["api_key"] == "sk-new-key-9999"

    def test_update_preserves_existing_key_when_masked(self, client, config_file):
        """When user sends back a masked key (****xxxx), the real key should be preserved."""
        response = client.put(
            "/api/v1/settings/llm",
            json={
                "providers": [
                    {
                        "name": "openai",
                        "type": "openai",
                        "api_key": "****5678",  # masked form of the fixture's key
                        "base_url": "https://api.openai.com/v1",
                    }
                ],
            },
        )
        assert response.status_code == 200

        # The original key must survive; the masked placeholder must not be written.
        saved = self._load_saved(config_file)
        assert saved["llm"]["providers"]["openai"]["api_key"] == "sk-test-12345678"

    def test_update_writes_to_config_file(self, client, config_file):
        response = client.put(
            "/api/v1/settings/llm",
            json={
                "model_aliases": {"mini": "openai/gpt-4o-mini"},
            },
        )
        assert response.status_code == 200

        saved = self._load_saved(config_file)
        assert saved["llm"]["model_aliases"]["mini"] == "openai/gpt-4o-mini"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# GET /settings/skills
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestGetSkillsSettings:
    """GET /api/v1/settings/skills returns the skills section of the config."""

    _URL = "/api/v1/settings/skills"

    def test_returns_skills_config(self, client):
        resp = client.get(self._URL)
        assert resp.status_code == 200
        body = resp.json()
        for field in ("paths", "auto_discover"):
            assert field in body

    def test_skills_paths_returned(self, client):
        body = client.get(self._URL).json()
        assert "/tmp/skills" in body["paths"]
        assert body["auto_discover"] is True
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# PUT /settings/skills
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestUpdateSkillsSettings:
    """PUT /api/v1/settings/skills persists the skills section to disk."""

    def test_update_skills_paths(self, client, config_file):
        payload = {
            "paths": ["/tmp/skills", "/tmp/new-skills"],
            "auto_discover": False,
        }
        resp = client.put("/api/v1/settings/skills", json=payload)
        assert resp.status_code == 200

        with open(config_file, encoding="utf-8") as fh:
            skills_section = yaml.safe_load(fh)["skills"]
        assert "/tmp/new-skills" in skills_section["paths"]
        assert skills_section["auto_discover"] is False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# GET /settings/kb
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestGetKbSettings:
    """GET /api/v1/settings/kb exposes the memory section of the config."""

    def test_returns_kb_config(self, client):
        resp = client.get("/api/v1/settings/kb")
        assert resp.status_code == 200
        assert "memory" in resp.json()

    def test_kb_memory_config_returned(self, client):
        memory = client.get("/api/v1/settings/kb").json()["memory"]
        assert "semantic" in memory
        assert memory["semantic"]["enabled"] is True
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# PUT /settings/kb
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestUpdateKbSettings:
    """PUT /api/v1/settings/kb persists the memory section to disk."""

    def test_update_kb_config(self, client, config_file):
        semantic = {
            "enabled": True,
            "base_url": "http://localhost:9090",
            "knowledge_base_ids": ["kb-new"],
            "search_mode": "rerank",
            "top_k": 10,
        }
        resp = client.put(
            "/api/v1/settings/kb",
            json={"memory": {"semantic": semantic}},
        )
        assert resp.status_code == 200

        with open(config_file, encoding="utf-8") as fh:
            on_disk = yaml.safe_load(fh)
        assert on_disk["memory"]["semantic"]["base_url"] == "http://localhost:9090"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# GET /settings/general
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestGetGeneralSettings:
    """GET /api/v1/settings/general returns server and logging settings as one flat object."""

    _URL = "/api/v1/settings/general"

    def test_returns_general_config(self, client):
        resp = client.get(self._URL)
        assert resp.status_code == 200
        body = resp.json()
        expected_keys = (
            "host",
            "port",
            "workers",
            "log_level",
            "log_format",
            "rate_limit",
            "cors_origins",
        )
        for key in expected_keys:
            assert key in body

    def test_general_config_values(self, client):
        body = client.get(self._URL).json()
        # Values mirror what the config_file fixture writes.
        expected = {
            "host": "0.0.0.0",
            "port": 8001,
            "log_level": "INFO",
            "rate_limit": 60,
        }
        for key, value in expected.items():
            assert body[key] == value
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# PUT /settings/general
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestUpdateGeneralSettings:
    """PUT /api/v1/settings/general writes back into the server and logging sections."""

    @staticmethod
    def _read_config(path):
        """Return the parsed YAML currently stored at *path*."""
        with open(path, encoding="utf-8") as fh:
            return yaml.safe_load(fh)

    def test_update_general_config(self, client, config_file):
        resp = client.put(
            "/api/v1/settings/general",
            json={"log_level": "DEBUG", "rate_limit": 100},
        )
        assert resp.status_code == 200

        on_disk = self._read_config(config_file)
        # log_level maps to logging.level; rate_limit lives under server.
        assert on_disk["logging"]["level"] == "DEBUG"
        assert on_disk["server"]["rate_limit"] == 100

    def test_update_cors_origins(self, client, config_file):
        origins = ["http://localhost:3000", "http://localhost:5173"]
        resp = client.put("/api/v1/settings/general", json={"cors_origins": origins})
        assert resp.status_code == 200

        on_disk = self._read_config(config_file)
        assert "http://localhost:3000" in on_disk["server"]["cors_origins"]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# No config file path available
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestNoConfigPath:
    """Settings writes are rejected when the server has no config file to write to."""

    def test_put_without_config_path_returns_400(self):
        """When server_config has no _config_path, PUT should return 400."""
        from agentkit.llm.gateway import LLMGateway

        # While the app is built, run with an empty environment and with
        # Path.exists always False. The assumption (confirm against create_app)
        # is that this stops it from locating an agentkit.yaml via an
        # environment variable or on disk, e.g. in the CWD.
        empty_env = unittest.mock.patch.dict(os.environ, {}, clear=True)
        no_files = unittest.mock.patch("pathlib.Path.exists", return_value=False)
        with empty_env, no_files:
            # No server_config is passed, so the app has no config path to write to.
            client = TestClient(create_app(llm_gateway=LLMGateway()))

        response = client.put(
            "/api/v1/settings/llm",
            json={"providers": [{"name": "test", "type": "openai"}]},
        )
        assert response.status_code == 400
|