From ac37a592f1ceb6e1b2b4ab913924830369f25529 Mon Sep 17 00:00:00 2001 From: Andrew Nester Date: Thu, 21 Dec 2023 16:45:23 +0100 Subject: [PATCH] Added exec.NewCommandExecutor to execute commands with correct interpreter (#1075) ## Changes Instead of handling command chaining ourselves, we execute the passed commands as-is by storing them in a temp file and passing that file to the correct interpreter (bash or cmd) based on the OS. Fixes #1065 ## Tests Added unit tests --- bundle/artifacts/whl/infer.go | 2 +- bundle/config/artifact.go | 21 ++--- bundle/config/artifacts_test.go | 18 +++++ bundle/scripts/scripts.go | 34 +++----- bundle/scripts/scripts_test.go | 6 +- libs/exec/exec.go | 101 ++++++++++++++++++++++++ libs/exec/exec_test.go | 136 ++++++++++++++++++++++++++++++++ libs/exec/interpreter.go | 123 +++++++++++++++++++++++++++++ 8 files changed, 399 insertions(+), 42 deletions(-) create mode 100644 bundle/config/artifacts_test.go create mode 100644 libs/exec/exec.go create mode 100644 libs/exec/exec_test.go create mode 100644 libs/exec/interpreter.go diff --git a/bundle/artifacts/whl/infer.go b/bundle/artifacts/whl/infer.go index dedecc30..dc2b8e23 100644 --- a/bundle/artifacts/whl/infer.go +++ b/bundle/artifacts/whl/infer.go @@ -33,7 +33,7 @@ func (m *infer) Apply(ctx context.Context, b *bundle.Bundle) error { // version=datetime.datetime.utcnow().strftime("%Y%m%d.%H%M%S"), // ... 
//) - artifact.BuildCommand = fmt.Sprintf("%s setup.py bdist_wheel", py) + artifact.BuildCommand = fmt.Sprintf(`"%s" setup.py bdist_wheel`, py) return nil } diff --git a/bundle/config/artifact.go b/bundle/config/artifact.go index 63ab6c48..2a1a92a1 100644 --- a/bundle/config/artifact.go +++ b/bundle/config/artifact.go @@ -1,14 +1,12 @@ package config import ( - "bytes" "context" "fmt" "path" - "strings" "github.com/databricks/cli/bundle/config/paths" - "github.com/databricks/cli/libs/process" + "github.com/databricks/cli/libs/exec" "github.com/databricks/databricks-sdk-go/service/compute" ) @@ -52,20 +50,11 @@ func (a *Artifact) Build(ctx context.Context) ([]byte, error) { return nil, fmt.Errorf("no build property defined") } - out := make([][]byte, 0) - commands := strings.Split(a.BuildCommand, " && ") - for _, command := range commands { - buildParts := strings.Split(command, " ") - var buf bytes.Buffer - _, err := process.Background(ctx, buildParts, - process.WithCombinedOutput(&buf), - process.WithDir(a.Path)) - if err != nil { - return buf.Bytes(), err - } - out = append(out, buf.Bytes()) + e, err := exec.NewCommandExecutor(a.Path) + if err != nil { + return nil, err } - return bytes.Join(out, []byte{}), nil + return e.Exec(ctx, a.BuildCommand) } func (a *Artifact) NormalisePaths() { diff --git a/bundle/config/artifacts_test.go b/bundle/config/artifacts_test.go new file mode 100644 index 00000000..5fa159fd --- /dev/null +++ b/bundle/config/artifacts_test.go @@ -0,0 +1,18 @@ +package config + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestArtifactBuild(t *testing.T) { + artifact := Artifact{ + BuildCommand: "echo 'Hello from build command'", + } + res, err := artifact.Build(context.Background()) + assert.NoError(t, err) + assert.NotNil(t, res) + assert.Equal(t, "Hello from build command\n", string(res)) +} diff --git a/bundle/scripts/scripts.go b/bundle/scripts/scripts.go index 90c1914f..2f13bc19 100644 --- 
a/bundle/scripts/scripts.go +++ b/bundle/scripts/scripts.go @@ -5,12 +5,12 @@ import ( "context" "fmt" "io" - "os/exec" "strings" "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/exec" "github.com/databricks/cli/libs/log" ) @@ -29,7 +29,12 @@ func (m *script) Name() string { } func (m *script) Apply(ctx context.Context, b *bundle.Bundle) error { - cmd, out, err := executeHook(ctx, b, m.scriptHook) + executor, err := exec.NewCommandExecutor(b.Config.Path) + if err != nil { + return err + } + + cmd, out, err := executeHook(ctx, executor, b, m.scriptHook) if err != nil { return err } @@ -50,32 +55,18 @@ func (m *script) Apply(ctx context.Context, b *bundle.Bundle) error { return cmd.Wait() } -func executeHook(ctx context.Context, b *bundle.Bundle, hook config.ScriptHook) (*exec.Cmd, io.Reader, error) { +func executeHook(ctx context.Context, executor *exec.Executor, b *bundle.Bundle, hook config.ScriptHook) (exec.Command, io.Reader, error) { command := getCommmand(b, hook) if command == "" { return nil, nil, nil } - interpreter, err := findInterpreter() + cmd, err := executor.StartCommand(ctx, string(command)) if err != nil { return nil, nil, err } - // TODO: switch to process.Background(...) 
- cmd := exec.CommandContext(ctx, interpreter, "-c", string(command)) - cmd.Dir = b.Config.Path - - outPipe, err := cmd.StdoutPipe() - if err != nil { - return nil, nil, err - } - - errPipe, err := cmd.StderrPipe() - if err != nil { - return nil, nil, err - } - - return cmd, io.MultiReader(outPipe, errPipe), cmd.Start() + return cmd, io.MultiReader(cmd.Stdout(), cmd.Stderr()), nil } func getCommmand(b *bundle.Bundle, hook config.ScriptHook) config.Command { @@ -85,8 +76,3 @@ func getCommmand(b *bundle.Bundle, hook config.ScriptHook) config.Command { return b.Config.Experimental.Scripts[hook] } - -func findInterpreter() (string, error) { - // At the moment we just return 'sh' on all platforms and use it to execute scripts - return "sh", nil -} diff --git a/bundle/scripts/scripts_test.go b/bundle/scripts/scripts_test.go index 8b7aa0d1..a8835b59 100644 --- a/bundle/scripts/scripts_test.go +++ b/bundle/scripts/scripts_test.go @@ -8,6 +8,7 @@ import ( "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config" + "github.com/databricks/cli/libs/exec" "github.com/stretchr/testify/require" ) @@ -21,7 +22,10 @@ func TestExecutesHook(t *testing.T) { }, }, } - _, out, err := executeHook(context.Background(), b, config.ScriptPreBuild) + + executor, err := exec.NewCommandExecutor(b.Config.Path) + require.NoError(t, err) + _, out, err := executeHook(context.Background(), executor, b, config.ScriptPreBuild) require.NoError(t, err) reader := bufio.NewReader(out) diff --git a/libs/exec/exec.go b/libs/exec/exec.go new file mode 100644 index 00000000..7ef6762b --- /dev/null +++ b/libs/exec/exec.go @@ -0,0 +1,101 @@ +package exec + +import ( + "context" + "io" + "os" + osexec "os/exec" +) + +type Command interface { + // Wait for command to terminate. It must have been previously started. + Wait() error + + // Stdout returns a pipe that will be connected to the command's standard output when the command starts. 
+ Stdout() io.ReadCloser + + // Stderr returns a pipe that will be connected to the command's standard error when the command starts. + Stderr() io.ReadCloser +} + +type command struct { + cmd *osexec.Cmd + execContext *execContext + stdout io.ReadCloser + stderr io.ReadCloser +} + +func (c *command) Wait() error { + // After the command has finished (cmd.Wait call), remove the temporary script file + defer os.Remove(c.execContext.scriptFile) + + err := c.cmd.Wait() + if err != nil { + return err + } + + return nil +} + +func (c *command) Stdout() io.ReadCloser { + return c.stdout +} + +func (c *command) Stderr() io.ReadCloser { + return c.stderr +} + +type Executor struct { + interpreter interpreter + dir string +} + +func NewCommandExecutor(dir string) (*Executor, error) { + interpreter, err := findInterpreter() + if err != nil { + return nil, err + } + return &Executor{ + interpreter: interpreter, + dir: dir, + }, nil +} + +func (e *Executor) StartCommand(ctx context.Context, command string) (Command, error) { + ec, err := e.interpreter.prepare(command) + if err != nil { + return nil, err + } + return e.start(ctx, ec) +} + +func (e *Executor) start(ctx context.Context, ec *execContext) (Command, error) { + cmd := osexec.CommandContext(ctx, ec.executable, ec.args...) 
+ cmd.Dir = e.dir + + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, err + } + + stderr, err := cmd.StderrPipe() + if err != nil { + return nil, err + } + + return &command{cmd, ec, stdout, stderr}, cmd.Start() +} + +func (e *Executor) Exec(ctx context.Context, command string) ([]byte, error) { + cmd, err := e.StartCommand(ctx, command) + if err != nil { + return nil, err + } + + res, err := io.ReadAll(io.MultiReader(cmd.Stdout(), cmd.Stderr())) + if err != nil { + return nil, err + } + + return res, cmd.Wait() +} diff --git a/libs/exec/exec_test.go b/libs/exec/exec_test.go new file mode 100644 index 00000000..a1d8d6ff --- /dev/null +++ b/libs/exec/exec_test.go @@ -0,0 +1,136 @@ +package exec + +import ( + "context" + "fmt" + "io" + "runtime" + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExecutorWithSimpleInput(t *testing.T) { + executor, err := NewCommandExecutor(".") + assert.NoError(t, err) + out, err := executor.Exec(context.Background(), "echo 'Hello'") + assert.NoError(t, err) + assert.NotNil(t, out) + assert.Equal(t, "Hello\n", string(out)) +} + +func TestExecutorWithComplexInput(t *testing.T) { + executor, err := NewCommandExecutor(".") + assert.NoError(t, err) + out, err := executor.Exec(context.Background(), "echo 'Hello' && echo 'World'") + assert.NoError(t, err) + assert.NotNil(t, out) + assert.Equal(t, "Hello\nWorld\n", string(out)) +} + +func TestExecutorWithInvalidCommand(t *testing.T) { + executor, err := NewCommandExecutor(".") + assert.NoError(t, err) + out, err := executor.Exec(context.Background(), "invalid-command") + assert.Error(t, err) + assert.Contains(t, string(out), "invalid-command: command not found") +} + +func TestExecutorWithInvalidCommandWithWindowsLikePath(t *testing.T) { + if runtime.GOOS != "windows" { + t.SkipNow() + } + + executor, err := NewCommandExecutor(".") + assert.NoError(t, err) + out, err := executor.Exec(context.Background(), `"C:\Program Files\invalid-command.exe"`) + 
assert.Error(t, err) + assert.Contains(t, string(out), "C:\\Program Files\\invalid-command.exe: No such file or directory") +} + +func TestFindBashInterpreterNonWindows(t *testing.T) { + if runtime.GOOS == "windows" { + t.SkipNow() + } + + interpreter, err := findBashInterpreter() + assert.NoError(t, err) + assert.NotEmpty(t, interpreter) + + e, err := NewCommandExecutor(".") + assert.NoError(t, err) + e.interpreter = interpreter + + assert.NoError(t, err) + out, err := e.Exec(context.Background(), `echo "Hello from bash"`) + assert.NoError(t, err) + + assert.Equal(t, "Hello from bash\n", string(out)) +} + +func TestFindCmdInterpreter(t *testing.T) { + if runtime.GOOS != "windows" { + t.SkipNow() + } + + interpreter, err := findCmdInterpreter() + assert.NoError(t, err) + assert.NotEmpty(t, interpreter) + + e, err := NewCommandExecutor(".") + assert.NoError(t, err) + e.interpreter = interpreter + + assert.NoError(t, err) + out, err := e.Exec(context.Background(), `echo "Hello from cmd"`) + assert.NoError(t, err) + + assert.Contains(t, string(out), "Hello from cmd") +} + +func TestExecutorCleanupsTempFiles(t *testing.T) { + executor, err := NewCommandExecutor(".") + assert.NoError(t, err) + + ec, err := executor.interpreter.prepare("echo 'Hello'") + assert.NoError(t, err) + + cmd, err := executor.start(context.Background(), ec) + assert.NoError(t, err) + + fileName := ec.args[1] + assert.FileExists(t, fileName) + + err = cmd.Wait() + assert.NoError(t, err) + assert.NoFileExists(t, fileName) +} + +func TestMultipleCommandsRunInParrallel(t *testing.T) { + executor, err := NewCommandExecutor(".") + assert.NoError(t, err) + + const count = 5 + var wg sync.WaitGroup + + for i := 0; i < count; i++ { + wg.Add(1) + cmd, err := executor.StartCommand(context.Background(), fmt.Sprintf("echo 'Hello %d'", i)) + go func(cmd Command, i int) { + defer wg.Done() + + stdout := cmd.Stdout() + out, err := io.ReadAll(stdout) + assert.NoError(t, err) + + err = cmd.Wait() + 
assert.NoError(t, err) + + assert.Equal(t, fmt.Sprintf("Hello %d\n", i), string(out)) + }(cmd, i) + assert.NoError(t, err) + } + + wg.Wait() +} diff --git a/libs/exec/interpreter.go b/libs/exec/interpreter.go new file mode 100644 index 00000000..e600e47f --- /dev/null +++ b/libs/exec/interpreter.go @@ -0,0 +1,123 @@ +package exec + +import ( + "errors" + "fmt" + "io" + "os" + osexec "os/exec" +) + +type interpreter interface { + prepare(string) (*execContext, error) +} + +type execContext struct { + executable string + args []string + scriptFile string +} + +type bashInterpreter struct { + executable string +} + +func (b *bashInterpreter) prepare(command string) (*execContext, error) { + filename, err := createTempScript(command, ".sh") + if err != nil { + return nil, err + } + + return &execContext{ + executable: b.executable, + args: []string{"-e", filename}, + scriptFile: filename, + }, nil +} + +type cmdInterpreter struct { + executable string +} + +func (c *cmdInterpreter) prepare(command string) (*execContext, error) { + filename, err := createTempScript(command, ".cmd") + if err != nil { + return nil, err + } + + return &execContext{ + executable: c.executable, + args: []string{"/D", "/E:ON", "/V:OFF", "/S", "/C", fmt.Sprintf(`CALL %s`, filename)}, + scriptFile: filename, + }, nil +} + +func findInterpreter() (interpreter, error) { + interpreter, err := findBashInterpreter() + if err != nil { + return nil, err + } + + if interpreter != nil { + return interpreter, nil + } + + interpreter, err = findCmdInterpreter() + if err != nil { + return nil, err + } + + if interpreter != nil { + return interpreter, nil + } + + return nil, errors.New("no interpreter found") +} + +func findBashInterpreter() (interpreter, error) { + // Lookup for bash executable first (Linux, MacOS, maybe Windows) + out, err := osexec.LookPath("bash") + if err != nil && !errors.Is(err, osexec.ErrNotFound) { + return nil, err + } + + // Bash executable is not found, returning early + if out 
== "" { + return nil, nil + } + + return &bashInterpreter{executable: out}, nil +} + +func findCmdInterpreter() (interpreter, error) { + // Lookup for CMD executable (Windows) + out, err := osexec.LookPath("cmd") + if err != nil && !errors.Is(err, osexec.ErrNotFound) { + return nil, err + } + + // CMD executable is not found, returning early + if out == "" { + return nil, nil + } + + return &cmdInterpreter{executable: out}, nil +} + +func createTempScript(command string, extension string) (string, error) { + file, err := os.CreateTemp(os.TempDir(), "cli-exec*"+extension) + if err != nil { + return "", err + } + + defer file.Close() + + _, err = io.WriteString(file, command) + if err != nil { + // Try to remove the file if we failed to write to it + os.Remove(file.Name()) + return "", err + } + + return file.Name(), nil +}