357 lines
12 KiB
Python
357 lines
12 KiB
Python
"""Tests for Terminal API routes"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
from agentkit.llm.gateway import LLMGateway
|
|
from agentkit.server.app import create_app
|
|
from agentkit.server.routes.terminal import (
|
|
_sessions,
|
|
_check_command_safety,
|
|
_is_dangerous,
|
|
_is_single_command_dangerous,
|
|
_get_or_create_session,
|
|
_cleanup_session,
|
|
)
|
|
from agentkit.skills.registry import SkillRegistry
|
|
from agentkit.tools.registry import ToolRegistry
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def clear_sessions():
    """Empty the module-level ``_sessions`` store around every test.

    The store is cleared once before the test body runs and once more
    after it finishes, so sessions never carry over between tests.
    """
    # Setup: start from an empty store.
    _sessions.clear()

    yield

    # Teardown: drop whatever the test created.
    _sessions.clear()
|
|
|
|
|
|
@pytest.fixture
def mock_llm_gateway():
    """Provide a default-constructed ``LLMGateway`` for the app fixture."""
    gateway = LLMGateway()
    return gateway
|
|
|
|
|
|
@pytest.fixture
def skill_registry():
    """Provide a default-constructed ``SkillRegistry`` for the app fixture."""
    registry = SkillRegistry()
    return registry
|
|
|
|
|
|
@pytest.fixture
def tool_registry():
    """Provide a default-constructed ``ToolRegistry`` for the app fixture."""
    registry = ToolRegistry()
    return registry
|
|
|
|
|
|
@pytest.fixture
def app(mock_llm_gateway, skill_registry, tool_registry):
    """Build the application under test from the gateway and registry fixtures."""
    return create_app(
        llm_gateway=mock_llm_gateway,
        skill_registry=skill_registry,
        tool_registry=tool_registry,
    )
|
|
|
|
|
|
@pytest.fixture
def client(app):
    """Wrap the ``app`` fixture in a ``TestClient`` for issuing HTTP requests."""
    test_client = TestClient(app)
    return test_client
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Security helper unit tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestIsSingleCommandDangerous:
    """Expectations for ``_is_single_command_dangerous`` on individual commands."""

    def test_safe_command_ls(self):
        """``ls`` is expected to be classified as not dangerous."""
        verdict = _is_single_command_dangerous("ls")
        assert verdict is False

    def test_safe_command_pwd(self):
        """``pwd`` is expected to be classified as not dangerous."""
        verdict = _is_single_command_dangerous("pwd")
        assert verdict is False

    def test_safe_command_git_status(self):
        """``git status`` is expected to be classified as not dangerous."""
        verdict = _is_single_command_dangerous("git status")
        assert verdict is False

    def test_safe_command_echo(self):
        """``echo hello`` is expected to be classified as not dangerous."""
        verdict = _is_single_command_dangerous("echo hello")
        assert verdict is False

    def test_safe_command_cat(self):
        """``cat file.txt`` is expected to be classified as not dangerous."""
        verdict = _is_single_command_dangerous("cat file.txt")
        assert verdict is False

    def test_dangerous_command_rm(self):
        """``rm file.txt`` is expected to be classified as dangerous."""
        verdict = _is_single_command_dangerous("rm file.txt")
        assert verdict is True

    def test_dangerous_command_rm_rf(self):
        """``rm -rf /`` is expected to be classified as dangerous."""
        verdict = _is_single_command_dangerous("rm -rf /")
        assert verdict is True

    def test_dangerous_command_mkfs(self):
        """``mkfs.ext4`` on a device is expected to be classified as dangerous."""
        verdict = _is_single_command_dangerous("mkfs.ext4 /dev/sda1")
        assert verdict is True

    def test_dangerous_command_git_push_force(self):
        """``git push --force`` is expected to be classified as dangerous."""
        verdict = _is_single_command_dangerous("git push --force")
        assert verdict is True

    def test_safe_command_git_add(self):
        """``git add .`` is expected to be classified as not dangerous."""
        # git add is safe even though git push --force is not
        verdict = _is_single_command_dangerous("git add .")
        assert verdict is False

    def test_dangerous_command_kill_9(self):
        """``kill -9`` is expected to be classified as dangerous."""
        verdict = _is_single_command_dangerous("kill -9 1234")
        assert verdict is True

    def test_safe_command_kill_normal(self):
        """Plain ``kill <pid>`` is expected to be classified as not dangerous."""
        # kill without dangerous flags is safe
        verdict = _is_single_command_dangerous("kill 1234")
        assert verdict is False

    def test_dangerous_command_pip_uninstall(self):
        """``pip uninstall`` is expected to be classified as dangerous."""
        verdict = _is_single_command_dangerous("pip uninstall package")
        assert verdict is True

    def test_safe_command_pip_list(self):
        """``pip list`` is expected to be classified as not dangerous."""
        verdict = _is_single_command_dangerous("pip list")
        assert verdict is False

    def test_dangerous_command_dd(self):
        """``dd`` writing to a device is expected to be classified as dangerous."""
        verdict = _is_single_command_dangerous("dd if=/dev/zero of=/dev/sda")
        assert verdict is True

    def test_unknown_command_is_dangerous(self):
        """An unrecognised binary name is expected to be classified as dangerous."""
        verdict = _is_single_command_dangerous("unknown_binary")
        assert verdict is True

    def test_empty_command_is_dangerous(self):
        """The empty string is expected to be classified as dangerous."""
        verdict = _is_single_command_dangerous("")
        assert verdict is True

    def test_safe_command_with_args(self):
        """``ls`` with flags and a path is expected to stay not dangerous."""
        verdict = _is_single_command_dangerous("ls -la /tmp")
        assert verdict is False

    def test_safe_command_docker_ps(self):
        """``docker ps`` is expected to be classified as not dangerous."""
        verdict = _is_single_command_dangerous("docker ps")
        assert verdict is False

    def test_dangerous_command_docker_rm(self):
        """``docker rm`` is expected to be classified as dangerous."""
        verdict = _is_single_command_dangerous("docker rm container")
        assert verdict is True
|
|
|
|
|
|
class TestIsDangerous:
    """Expectations for ``_is_dangerous`` on full command lines."""

    def test_safe_single_command(self):
        """A lone ``ls -la`` is expected to be reported as not dangerous."""
        verdict = _is_dangerous("ls -la")
        assert verdict is False

    def test_dangerous_single_command(self):
        """A lone ``rm -rf /`` is expected to be reported as dangerous."""
        verdict = _is_dangerous("rm -rf /")
        assert verdict is True

    def test_safe_pipe(self):
        """``ls`` piped into ``grep`` is expected to be reported as not dangerous."""
        verdict = _is_dangerous("ls | grep foo")
        assert verdict is False

    def test_dangerous_pipe(self):
        """A pipe ending in ``rm`` is expected to be reported as dangerous."""
        verdict = _is_dangerous("ls | rm")
        assert verdict is True

    def test_chain_operator_is_dangerous(self):
        """A ``;``-chained command line is expected to be reported as dangerous."""
        verdict = _is_dangerous("ls ; rm -rf /")
        assert verdict is True

    def test_and_operator_is_dangerous(self):
        """A ``&&``-chained command line is expected to be reported as dangerous."""
        verdict = _is_dangerous("ls && rm -rf /")
        assert verdict is True

    def test_or_operator_is_dangerous(self):
        """A ``||``-chained command line is expected to be reported as dangerous."""
        verdict = _is_dangerous("ls || rm -rf /")
        assert verdict is True

    def test_command_substitution_is_dangerous(self):
        """A ``$(...)`` substitution is expected to be reported as dangerous."""
        verdict = _is_dangerous("echo $(rm -rf /)")
        assert verdict is True

    def test_redirection_is_dangerous(self):
        """An output redirection with ``>`` is expected to be reported as dangerous."""
        verdict = _is_dangerous("echo foo > /etc/passwd")
        assert verdict is True

    def test_safe_pipe_all_safe(self):
        """A three-stage ``cat | grep | sort`` pipe is expected to be not dangerous."""
        verdict = _is_dangerous("cat file.txt | grep pattern | sort")
        assert verdict is False
|
|
|
|
|
|
class TestCheckCommandSafety:
    """Expectations for the result mapping returned by ``_check_command_safety``."""

    def test_safe_command(self):
        """``ls`` should yield a result whose ``"safe"`` entry is ``True``."""
        outcome = _check_command_safety("ls", "test-session")
        assert outcome["safe"] is True

    def test_dangerous_command(self):
        """``rm -rf /`` should yield ``"safe": False`` and include a ``"reason"``."""
        outcome = _check_command_safety("rm -rf /", "test-session")
        assert outcome["safe"] is False
        assert "reason" in outcome

    def test_session_whitelist(self):
        """Adding ``rm`` to the session whitelist should make ``rm file.txt`` safe."""
        # Create a session with whitelist
        session = _get_or_create_session("test-session")
        session.whitelist.add("rm")

        outcome = _check_command_safety("rm file.txt", "test-session")
        assert outcome["safe"] is True

    def test_session_whitelist_prefix_match(self):
        """Whitelisting ``docker`` should make ``docker rm container`` safe."""
        session = _get_or_create_session("test-session")
        session.whitelist.add("docker")

        outcome = _check_command_safety("docker rm container", "test-session")
        assert outcome["safe"] is True
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Session management tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSessionManagement:
    """Expectations for ``_get_or_create_session`` and ``_cleanup_session``."""

    def test_get_or_create_new(self):
        """Calling without an id should return a state with an id and non-empty cwd."""
        session = _get_or_create_session()
        assert session.session_id is not None
        assert session.cwd != ""

    def test_get_or_create_with_id(self):
        """Requesting the same id twice should return the very same state object."""
        first = _get_or_create_session("my-session")
        assert first.session_id == "my-session"

        # Get existing
        second = _get_or_create_session("my-session")
        assert second.session_id == "my-session"
        assert first is second

    def test_cleanup_session(self):
        """Cleaning up a session should remove its id from ``_sessions``."""
        _get_or_create_session("cleanup-test")
        assert "cleanup-test" in _sessions

        _cleanup_session("cleanup-test")
        assert "cleanup-test" not in _sessions

    def test_cleanup_nonexistent_session(self):
        """Cleaning up an unknown id is expected to complete without raising."""
        # Should not raise
        _cleanup_session("nonexistent")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# REST endpoint tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestExecuteEndpoint:
    """Tests for ``POST /api/v1/terminal/execute``."""

    def test_execute_safe_command(self, client):
        """``echo hello`` should run, exit 0, and need no confirmation."""
        payload = {"command": "echo hello"}
        response = client.post("/api/v1/terminal/execute", json=payload)
        assert response.status_code == 200

        body = response.json()
        assert body["exit_code"] == 0
        assert "hello" in body["output"]
        assert body["confirmation_required"] is False
        assert body["session_id"] is not None

    def test_execute_dangerous_command_returns_confirmation(self, client):
        """``rm -rf /`` should come back flagged for confirmation with exit code 126."""
        payload = {"command": "rm -rf /"}
        response = client.post("/api/v1/terminal/execute", json=payload)
        assert response.status_code == 200

        body = response.json()
        assert body["confirmation_required"] is True
        assert body["exit_code"] == 126
        assert body["reason"] is not None

    def test_execute_with_session_id(self, client):
        """A caller-supplied session id should be echoed back on every request."""
        # First request creates session
        first_payload = {"command": "echo first", "session_id": "test-session-1"}
        first = client.post("/api/v1/terminal/execute", json=first_payload)
        assert first.status_code == 200
        assert first.json()["session_id"] == "test-session-1"

        # Second request reuses session
        second_payload = {"command": "echo second", "session_id": "test-session-1"}
        second = client.post("/api/v1/terminal/execute", json=second_payload)
        assert second.status_code == 200
        assert second.json()["session_id"] == "test-session-1"

    def test_execute_empty_command(self, client):
        """An empty command should return 200 with confirmation required."""
        payload = {"command": ""}
        response = client.post("/api/v1/terminal/execute", json=payload)

        # Empty command is dangerous, so confirmation_required
        assert response.status_code == 200
        body = response.json()
        assert body["confirmation_required"] is True

    def test_execute_pwd_command(self, client):
        """``pwd`` should exit 0 and produce non-blank output."""
        payload = {"command": "pwd"}
        response = client.post("/api/v1/terminal/execute", json=payload)
        assert response.status_code == 200

        body = response.json()
        assert body["exit_code"] == 0
        assert body["output"].strip() != ""
|
|
|
|
|
|
class TestListSessionsEndpoint:
    """Tests for ``GET /api/v1/terminal/sessions``."""

    def test_list_empty(self, client):
        """With no sessions created, the listing should be an empty list."""
        response = client.get("/api/v1/terminal/sessions")
        assert response.status_code == 200

        body = response.json()
        assert body["sessions"] == []

    def test_list_after_execute(self, client):
        """A session created via execute should show up in the listing."""
        payload = {"command": "echo hello", "session_id": "list-test"}
        client.post("/api/v1/terminal/execute", json=payload)

        response = client.get("/api/v1/terminal/sessions")
        assert response.status_code == 200

        body = response.json()
        assert len(body["sessions"]) >= 1
        listed_ids = [entry["session_id"] for entry in body["sessions"]]
        assert "list-test" in listed_ids
|
|
|
|
|
|
class TestHistoryEndpoint:
    """Tests for ``GET /api/v1/terminal/sessions/{session_id}/history``."""

    def test_history_after_execute(self, client):
        """One executed command should yield one history entry with its exit code."""
        payload = {"command": "echo hello", "session_id": "history-test"}
        client.post("/api/v1/terminal/execute", json=payload)

        response = client.get("/api/v1/terminal/sessions/history-test/history")
        assert response.status_code == 200

        body = response.json()
        assert len(body["history"]) == 1
        entry = body["history"][0]
        assert entry["command"] == "echo hello"
        assert entry["exit_code"] == 0

    def test_history_not_found(self, client):
        """Requesting history for an unknown session should return 404."""
        response = client.get("/api/v1/terminal/sessions/nonexistent/history")
        assert response.status_code == 404

    def test_history_multiple_commands(self, client):
        """Two executed commands in one session should yield two history entries."""
        first_payload = {"command": "echo first", "session_id": "multi-test"}
        client.post("/api/v1/terminal/execute", json=first_payload)

        second_payload = {"command": "echo second", "session_id": "multi-test"}
        client.post("/api/v1/terminal/execute", json=second_payload)

        response = client.get("/api/v1/terminal/sessions/multi-test/history")
        assert response.status_code == 200

        body = response.json()
        assert len(body["history"]) == 2
|
|
|
|
|
|
class TestCloseSessionEndpoint:
    """Tests for ``DELETE /api/v1/terminal/sessions/{session_id}``."""

    def test_close_existing_session(self, client):
        """Deleting an existing session should return 200 with status ``closed``."""
        payload = {"command": "echo hello", "session_id": "close-test"}
        client.post("/api/v1/terminal/execute", json=payload)

        response = client.delete("/api/v1/terminal/sessions/close-test")
        assert response.status_code == 200

        body = response.json()
        assert body["status"] == "closed"

    def test_close_not_found(self, client):
        """Deleting an unknown session should return 404."""
        response = client.delete("/api/v1/terminal/sessions/nonexistent")
        assert response.status_code == 404
|