Added exec.NewCommandExecutor to execute commands with correct interpreter (#1075)

## Changes
Instead of handling command chaining ourselves, we execute passed
commands as-is by storing them, in temp file and passing to correct
interpreter (bash or cmd) based on OS.

Fixes #1065 

## Tests
Added unit tests
This commit is contained in:
Andrew Nester 2023-12-21 16:45:23 +01:00 committed by GitHub
parent 55732bc6ac
commit ac37a592f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 399 additions and 42 deletions

View File

@ -33,7 +33,7 @@ func (m *infer) Apply(ctx context.Context, b *bundle.Bundle) error {
// version=datetime.datetime.utcnow().strftime("%Y%m%d.%H%M%S"),
// ...
//)
artifact.BuildCommand = fmt.Sprintf("%s setup.py bdist_wheel", py)
artifact.BuildCommand = fmt.Sprintf(`"%s" setup.py bdist_wheel`, py)
return nil
}

View File

@ -1,14 +1,12 @@
package config
import (
"bytes"
"context"
"fmt"
"path"
"strings"
"github.com/databricks/cli/bundle/config/paths"
"github.com/databricks/cli/libs/process"
"github.com/databricks/cli/libs/exec"
"github.com/databricks/databricks-sdk-go/service/compute"
)
@ -52,20 +50,11 @@ func (a *Artifact) Build(ctx context.Context) ([]byte, error) {
return nil, fmt.Errorf("no build property defined")
}
out := make([][]byte, 0)
commands := strings.Split(a.BuildCommand, " && ")
for _, command := range commands {
buildParts := strings.Split(command, " ")
var buf bytes.Buffer
_, err := process.Background(ctx, buildParts,
process.WithCombinedOutput(&buf),
process.WithDir(a.Path))
e, err := exec.NewCommandExecutor(a.Path)
if err != nil {
return buf.Bytes(), err
return nil, err
}
out = append(out, buf.Bytes())
}
return bytes.Join(out, []byte{}), nil
return e.Exec(ctx, a.BuildCommand)
}
func (a *Artifact) NormalisePaths() {

View File

@ -0,0 +1,18 @@
package config
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestArtifactBuild(t *testing.T) {
artifact := Artifact{
BuildCommand: "echo 'Hello from build command'",
}
res, err := artifact.Build(context.Background())
assert.NoError(t, err)
assert.NotNil(t, res)
assert.Equal(t, "Hello from build command\n", string(res))
}

View File

@ -5,12 +5,12 @@ import (
"context"
"fmt"
"io"
"os/exec"
"strings"
"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/exec"
"github.com/databricks/cli/libs/log"
)
@ -29,7 +29,12 @@ func (m *script) Name() string {
}
func (m *script) Apply(ctx context.Context, b *bundle.Bundle) error {
cmd, out, err := executeHook(ctx, b, m.scriptHook)
executor, err := exec.NewCommandExecutor(b.Config.Path)
if err != nil {
return err
}
cmd, out, err := executeHook(ctx, executor, b, m.scriptHook)
if err != nil {
return err
}
@ -50,32 +55,18 @@ func (m *script) Apply(ctx context.Context, b *bundle.Bundle) error {
return cmd.Wait()
}
func executeHook(ctx context.Context, b *bundle.Bundle, hook config.ScriptHook) (*exec.Cmd, io.Reader, error) {
func executeHook(ctx context.Context, executor *exec.Executor, b *bundle.Bundle, hook config.ScriptHook) (exec.Command, io.Reader, error) {
command := getCommmand(b, hook)
if command == "" {
return nil, nil, nil
}
interpreter, err := findInterpreter()
cmd, err := executor.StartCommand(ctx, string(command))
if err != nil {
return nil, nil, err
}
// TODO: switch to process.Background(...)
cmd := exec.CommandContext(ctx, interpreter, "-c", string(command))
cmd.Dir = b.Config.Path
outPipe, err := cmd.StdoutPipe()
if err != nil {
return nil, nil, err
}
errPipe, err := cmd.StderrPipe()
if err != nil {
return nil, nil, err
}
return cmd, io.MultiReader(outPipe, errPipe), cmd.Start()
return cmd, io.MultiReader(cmd.Stdout(), cmd.Stderr()), nil
}
func getCommmand(b *bundle.Bundle, hook config.ScriptHook) config.Command {
@ -85,8 +76,3 @@ func getCommmand(b *bundle.Bundle, hook config.ScriptHook) config.Command {
return b.Config.Experimental.Scripts[hook]
}
func findInterpreter() (string, error) {
// At the moment we just return 'sh' on all platforms and use it to execute scripts
return "sh", nil
}

View File

@ -8,6 +8,7 @@ import (
"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/libs/exec"
"github.com/stretchr/testify/require"
)
@ -21,7 +22,10 @@ func TestExecutesHook(t *testing.T) {
},
},
}
_, out, err := executeHook(context.Background(), b, config.ScriptPreBuild)
executor, err := exec.NewCommandExecutor(b.Config.Path)
require.NoError(t, err)
_, out, err := executeHook(context.Background(), executor, b, config.ScriptPreBuild)
require.NoError(t, err)
reader := bufio.NewReader(out)

101
libs/exec/exec.go Normal file
View File

@ -0,0 +1,101 @@
package exec
import (
"context"
"io"
"os"
osexec "os/exec"
)
type Command interface {
// Wait for command to terminate. It must have been previously started.
Wait() error
// StdinPipe returns a pipe that will be connected to the command's standard input when the command starts.
Stdout() io.ReadCloser
// StderrPipe returns a pipe that will be connected to the command's standard error when the command starts.
Stderr() io.ReadCloser
}
type command struct {
cmd *osexec.Cmd
execContext *execContext
stdout io.ReadCloser
stderr io.ReadCloser
}
func (c *command) Wait() error {
// After the command has finished (cmd.Wait call), remove the temporary script file
defer os.Remove(c.execContext.scriptFile)
err := c.cmd.Wait()
if err != nil {
return err
}
return nil
}
func (c *command) Stdout() io.ReadCloser {
return c.stdout
}
func (c *command) Stderr() io.ReadCloser {
return c.stderr
}
type Executor struct {
interpreter interpreter
dir string
}
func NewCommandExecutor(dir string) (*Executor, error) {
interpreter, err := findInterpreter()
if err != nil {
return nil, err
}
return &Executor{
interpreter: interpreter,
dir: dir,
}, nil
}
func (e *Executor) StartCommand(ctx context.Context, command string) (Command, error) {
ec, err := e.interpreter.prepare(command)
if err != nil {
return nil, err
}
return e.start(ctx, ec)
}
func (e *Executor) start(ctx context.Context, ec *execContext) (Command, error) {
cmd := osexec.CommandContext(ctx, ec.executable, ec.args...)
cmd.Dir = e.dir
stdout, err := cmd.StdoutPipe()
if err != nil {
return nil, err
}
stderr, err := cmd.StderrPipe()
if err != nil {
return nil, err
}
return &command{cmd, ec, stdout, stderr}, cmd.Start()
}
func (e *Executor) Exec(ctx context.Context, command string) ([]byte, error) {
cmd, err := e.StartCommand(ctx, command)
if err != nil {
return nil, err
}
res, err := io.ReadAll(io.MultiReader(cmd.Stdout(), cmd.Stderr()))
if err != nil {
return nil, err
}
return res, cmd.Wait()
}

136
libs/exec/exec_test.go Normal file
View File

@ -0,0 +1,136 @@
package exec
import (
"context"
"fmt"
"io"
"runtime"
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
func TestExecutorWithSimpleInput(t *testing.T) {
executor, err := NewCommandExecutor(".")
assert.NoError(t, err)
out, err := executor.Exec(context.Background(), "echo 'Hello'")
assert.NoError(t, err)
assert.NotNil(t, out)
assert.Equal(t, "Hello\n", string(out))
}
func TestExecutorWithComplexInput(t *testing.T) {
executor, err := NewCommandExecutor(".")
assert.NoError(t, err)
out, err := executor.Exec(context.Background(), "echo 'Hello' && echo 'World'")
assert.NoError(t, err)
assert.NotNil(t, out)
assert.Equal(t, "Hello\nWorld\n", string(out))
}
func TestExecutorWithInvalidCommand(t *testing.T) {
executor, err := NewCommandExecutor(".")
assert.NoError(t, err)
out, err := executor.Exec(context.Background(), "invalid-command")
assert.Error(t, err)
assert.Contains(t, string(out), "invalid-command: command not found")
}
func TestExecutorWithInvalidCommandWithWindowsLikePath(t *testing.T) {
if runtime.GOOS != "windows" {
t.SkipNow()
}
executor, err := NewCommandExecutor(".")
assert.NoError(t, err)
out, err := executor.Exec(context.Background(), `"C:\Program Files\invalid-command.exe"`)
assert.Error(t, err)
assert.Contains(t, string(out), "C:\\Program Files\\invalid-command.exe: No such file or directory")
}
func TestFindBashInterpreterNonWindows(t *testing.T) {
if runtime.GOOS == "windows" {
t.SkipNow()
}
interpreter, err := findBashInterpreter()
assert.NoError(t, err)
assert.NotEmpty(t, interpreter)
e, err := NewCommandExecutor(".")
assert.NoError(t, err)
e.interpreter = interpreter
assert.NoError(t, err)
out, err := e.Exec(context.Background(), `echo "Hello from bash"`)
assert.NoError(t, err)
assert.Equal(t, "Hello from bash\n", string(out))
}
func TestFindCmdInterpreter(t *testing.T) {
if runtime.GOOS != "windows" {
t.SkipNow()
}
interpreter, err := findCmdInterpreter()
assert.NoError(t, err)
assert.NotEmpty(t, interpreter)
e, err := NewCommandExecutor(".")
assert.NoError(t, err)
e.interpreter = interpreter
assert.NoError(t, err)
out, err := e.Exec(context.Background(), `echo "Hello from cmd"`)
assert.NoError(t, err)
assert.Contains(t, string(out), "Hello from cmd")
}
func TestExecutorCleanupsTempFiles(t *testing.T) {
executor, err := NewCommandExecutor(".")
assert.NoError(t, err)
ec, err := executor.interpreter.prepare("echo 'Hello'")
assert.NoError(t, err)
cmd, err := executor.start(context.Background(), ec)
assert.NoError(t, err)
fileName := ec.args[1]
assert.FileExists(t, fileName)
err = cmd.Wait()
assert.NoError(t, err)
assert.NoFileExists(t, fileName)
}
func TestMultipleCommandsRunInParrallel(t *testing.T) {
executor, err := NewCommandExecutor(".")
assert.NoError(t, err)
const count = 5
var wg sync.WaitGroup
for i := 0; i < count; i++ {
wg.Add(1)
cmd, err := executor.StartCommand(context.Background(), fmt.Sprintf("echo 'Hello %d'", i))
go func(cmd Command, i int) {
defer wg.Done()
stdout := cmd.Stdout()
out, err := io.ReadAll(stdout)
assert.NoError(t, err)
err = cmd.Wait()
assert.NoError(t, err)
assert.Equal(t, fmt.Sprintf("Hello %d\n", i), string(out))
}(cmd, i)
assert.NoError(t, err)
}
wg.Wait()
}

123
libs/exec/interpreter.go Normal file
View File

@ -0,0 +1,123 @@
package exec
import (
"errors"
"fmt"
"io"
"os"
osexec "os/exec"
)
type interpreter interface {
prepare(string) (*execContext, error)
}
type execContext struct {
executable string
args []string
scriptFile string
}
type bashInterpreter struct {
executable string
}
func (b *bashInterpreter) prepare(command string) (*execContext, error) {
filename, err := createTempScript(command, ".sh")
if err != nil {
return nil, err
}
return &execContext{
executable: b.executable,
args: []string{"-e", filename},
scriptFile: filename,
}, nil
}
type cmdInterpreter struct {
executable string
}
func (c *cmdInterpreter) prepare(command string) (*execContext, error) {
filename, err := createTempScript(command, ".cmd")
if err != nil {
return nil, err
}
return &execContext{
executable: c.executable,
args: []string{"/D", "/E:ON", "/V:OFF", "/S", "/C", fmt.Sprintf(`CALL %s`, filename)},
scriptFile: filename,
}, nil
}
func findInterpreter() (interpreter, error) {
interpreter, err := findBashInterpreter()
if err != nil {
return nil, err
}
if interpreter != nil {
return interpreter, nil
}
interpreter, err = findCmdInterpreter()
if err != nil {
return nil, err
}
if interpreter != nil {
return interpreter, nil
}
return nil, errors.New("no interpreter found")
}
func findBashInterpreter() (interpreter, error) {
// Lookup for bash executable first (Linux, MacOS, maybe Windows)
out, err := osexec.LookPath("bash")
if err != nil && !errors.Is(err, osexec.ErrNotFound) {
return nil, err
}
// Bash executable is not found, returning early
if out == "" {
return nil, nil
}
return &bashInterpreter{executable: out}, nil
}
func findCmdInterpreter() (interpreter, error) {
// Lookup for CMD executable (Windows)
out, err := osexec.LookPath("cmd")
if err != nil && !errors.Is(err, osexec.ErrNotFound) {
return nil, err
}
// CMD executable is not found, returning early
if out == "" {
return nil, nil
}
return &cmdInterpreter{executable: out}, nil
}
func createTempScript(command string, extension string) (string, error) {
file, err := os.CreateTemp(os.TempDir(), "cli-exec*"+extension)
if err != nil {
return "", err
}
defer file.Close()
_, err = io.WriteString(file, command)
if err != nil {
// Try to remove the file if we failed to write to it
os.Remove(file.Name())
return "", err
}
return file.Name(), nil
}