From 7171874db0a42d81e89955d518367c185b21c1c6 Mon Sep 17 00:00:00 2001 From: Serge Smertin <259697+nfx@users.noreply.github.com> Date: Wed, 27 Sep 2023 11:04:44 +0200 Subject: [PATCH] Added `process.Background()` and `process.Forwarded()` (#804) ## Changes This PR adds higher-level wrappers for calling subprocesses. One of the steps to get https://github.com/databricks/cli/pull/637 in, as previously discussed. The reason to add `process.Forwarded()` is to proxy Python's `input()` calls from a child process seamlessly. Another use-case is plugging in `less` as a pager for the list results. ## Tests `make test` --- bundle/config/artifact.go | 13 ++--- bundle/scripts/scripts.go | 1 + libs/env/context.go | 19 +++++++ libs/env/context_test.go | 8 +++ libs/git/clone.go | 21 +++----- libs/process/background.go | 59 +++++++++++++++++++++ libs/process/background_test.go | 91 +++++++++++++++++++++++++++++++++ libs/process/forwarded.go | 43 ++++++++++++++++ libs/process/forwarded_test.go | 43 ++++++++++++++++ libs/process/opts.go | 57 +++++++++++++++++++++ libs/process/opts_test.go | 47 +++++++++++++++++ python/runner.go | 6 ++- python/runner_test.go | 6 +-- 13 files changed, 390 insertions(+), 24 deletions(-) create mode 100644 libs/process/background.go create mode 100644 libs/process/background_test.go create mode 100644 libs/process/forwarded.go create mode 100644 libs/process/forwarded_test.go create mode 100644 libs/process/opts.go create mode 100644 libs/process/opts_test.go diff --git a/bundle/config/artifact.go b/bundle/config/artifact.go index d7048a02..755116eb 100644 --- a/bundle/config/artifact.go +++ b/bundle/config/artifact.go @@ -4,11 +4,11 @@ import ( "bytes" "context" "fmt" - "os/exec" "path" "strings" "github.com/databricks/cli/bundle/config/paths" + "github.com/databricks/cli/libs/process" "github.com/databricks/databricks-sdk-go/service/compute" ) @@ -56,13 +56,14 @@ func (a *Artifact) Build(ctx context.Context) ([]byte, error) { commands := 
strings.Split(a.BuildCommand, " && ") for _, command := range commands { buildParts := strings.Split(command, " ") - cmd := exec.CommandContext(ctx, buildParts[0], buildParts[1:]...) - cmd.Dir = a.Path - res, err := cmd.CombinedOutput() + var buf bytes.Buffer + _, err := process.Background(ctx, buildParts, + process.WithCombinedOutput(&buf), + process.WithDir(a.Path)) if err != nil { - return res, err + return buf.Bytes(), err } - out = append(out, res) + out = append(out, buf.Bytes()) } return bytes.Join(out, []byte{}), nil } diff --git a/bundle/scripts/scripts.go b/bundle/scripts/scripts.go index 1a8a471c..90c1914f 100644 --- a/bundle/scripts/scripts.go +++ b/bundle/scripts/scripts.go @@ -61,6 +61,7 @@ func executeHook(ctx context.Context, b *bundle.Bundle, hook config.ScriptHook) return nil, nil, err } + // TODO: switch to process.Background(...) cmd := exec.CommandContext(ctx, interpreter, "-c", string(command)) cmd.Dir = b.Config.Path diff --git a/libs/env/context.go b/libs/env/context.go index cf04c1ec..bbe294d7 100644 --- a/libs/env/context.go +++ b/libs/env/context.go @@ -3,6 +3,7 @@ package env import ( "context" "os" + "strings" ) var envContextKey int @@ -61,3 +62,21 @@ func Set(ctx context.Context, key, value string) context.Context { m[key] = value return setMap(ctx, m) } + +// All returns the union of the environment variables from os.Environ +// and those set through this package. `env.Set(ctx, x, y)` will override x from os.Environ. 
+func All(ctx context.Context) map[string]string { + m := map[string]string{} + for _, line := range os.Environ() { + split := strings.SplitN(line, "=", 2) + if len(split) != 2 { + continue + } + m[split[0]] = split[1] + } + // override existing environment variables with the ones we set + for k, v := range getMap(ctx) { + m[k] = v + } + return m +} diff --git a/libs/env/context_test.go b/libs/env/context_test.go index 9ff19459..39553448 100644 --- a/libs/env/context_test.go +++ b/libs/env/context_test.go @@ -38,4 +38,12 @@ func TestContext(t *testing.T) { assert.Equal(t, "qux", Get(ctx2, "FOO")) assert.Equal(t, "baz", Get(ctx1, "FOO")) assert.Equal(t, "bar", Get(ctx0, "FOO")) + + ctx3 := Set(ctx2, "BAR", "x=y") + + all := All(ctx3) + assert.NotNil(t, all) + assert.Equal(t, "qux", all["FOO"]) + assert.Equal(t, "x=y", all["BAR"]) + assert.NotEmpty(t, all["PATH"]) } diff --git a/libs/git/clone.go b/libs/git/clone.go index af7ffa4b..e7d001cd 100644 --- a/libs/git/clone.go +++ b/libs/git/clone.go @@ -1,13 +1,14 @@ package git import ( - "bytes" "context" "errors" "fmt" "os/exec" "regexp" "strings" + + "github.com/databricks/cli/libs/process" ) // source: https://stackoverflow.com/questions/59081778/rules-for-special-characters-in-github-repository-name @@ -42,24 +43,18 @@ func (opts cloneOptions) args() []string { } func (opts cloneOptions) clone(ctx context.Context) error { - cmd := exec.CommandContext(ctx, "git", opts.args()...) - var cmdErr bytes.Buffer - cmd.Stderr = &cmdErr - - // start git clone - err := cmd.Start() + // start and wait for git clone to complete + _, err := process.Background(ctx, append([]string{"git"}, opts.args()...)) if errors.Is(err, exec.ErrNotFound) { return fmt.Errorf("please install git CLI to clone a repository: %w", err) } + var processErr *process.ProcessError + if errors.As(err, &processErr) { + return fmt.Errorf("git clone failed: %w. 
%s", err, processErr.Stderr) + } if err != nil { return fmt.Errorf("git clone failed: %w", err) } - - // wait for git clone to complete - err = cmd.Wait() - if err != nil { - return fmt.Errorf("git clone failed: %w. %s", err, cmdErr.String()) - } return nil } diff --git a/libs/process/background.go b/libs/process/background.go new file mode 100644 index 00000000..26178a1d --- /dev/null +++ b/libs/process/background.go @@ -0,0 +1,59 @@ +package process + +import ( + "bytes" + "context" + "fmt" + "os/exec" + "strings" + + "github.com/databricks/cli/libs/env" + "github.com/databricks/cli/libs/log" +) + +type ProcessError struct { + Command string + Err error + Stdout string + Stderr string +} + +func (perr *ProcessError) Unwrap() error { + return perr.Err +} + +func (perr *ProcessError) Error() string { + return fmt.Sprintf("%s: %s", perr.Command, perr.Err) +} + +func Background(ctx context.Context, args []string, opts ...execOption) (string, error) { + commandStr := strings.Join(args, " ") + log.Debugf(ctx, "running: %s", commandStr) + cmd := exec.CommandContext(ctx, args[0], args[1:]...) + stdout := bytes.Buffer{} + stderr := bytes.Buffer{} + // For background processes, there's no standard input + cmd.Stdin = nil + cmd.Stdout = &stdout + cmd.Stderr = &stderr + // we pull the env through libs/env such that we can run + // parallel tests with anything using libs/process. 
+ for k, v := range env.All(ctx) { + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v)) + } + for _, o := range opts { + err := o(ctx, cmd) + if err != nil { + return "", err + } + } + if err := cmd.Run(); err != nil { + return stdout.String(), &ProcessError{ + Err: err, + Command: commandStr, + Stdout: stdout.String(), + Stderr: stderr.String(), + } + } + return stdout.String(), nil +} diff --git a/libs/process/background_test.go b/libs/process/background_test.go new file mode 100644 index 00000000..94f7e881 --- /dev/null +++ b/libs/process/background_test.go @@ -0,0 +1,91 @@ +package process + +import ( + "bytes" + "context" + "fmt" + "os" + "os/exec" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBackgroundUnwrapsNotFound(t *testing.T) { + ctx := context.Background() + _, err := Background(ctx, []string{"/bin/meeecho", "1"}) + assert.ErrorIs(t, err, os.ErrNotExist) +} + +func TestBackground(t *testing.T) { + ctx := context.Background() + res, err := Background(ctx, []string{"echo", "1"}, WithDir("/")) + assert.NoError(t, err) + assert.Equal(t, "1", strings.TrimSpace(res)) +} + +func TestBackgroundOnlyStdoutGetsoutOnSuccess(t *testing.T) { + ctx := context.Background() + res, err := Background(ctx, []string{ + "python3", "-c", "import sys; sys.stderr.write('1'); sys.stdout.write('2')", + }) + assert.NoError(t, err) + assert.Equal(t, "2", res) +} + +func TestBackgroundCombinedOutput(t *testing.T) { + ctx := context.Background() + buf := bytes.Buffer{} + res, err := Background(ctx, []string{ + "python3", "-c", "import sys, time; " + + `sys.stderr.write("1\n"); sys.stderr.flush(); ` + + "time.sleep(0.001); " + + "print('2', flush=True); sys.stdout.flush(); " + + "time.sleep(0.001)", + }, WithCombinedOutput(&buf)) + assert.NoError(t, err) + assert.Equal(t, "2", strings.TrimSpace(res)) + assert.Equal(t, "1\n2\n", strings.ReplaceAll(buf.String(), "\r", "")) +} + +func TestBackgroundCombinedOutputFailure(t *testing.T) { + ctx := 
context.Background() + buf := bytes.Buffer{} + res, err := Background(ctx, []string{ + "python3", "-c", "import sys, time; " + + `sys.stderr.write("1\n"); sys.stderr.flush(); ` + + "time.sleep(0.001); " + + "print('2', flush=True); sys.stdout.flush(); " + + "time.sleep(0.001); " + + "sys.exit(42)", + }, WithCombinedOutput(&buf)) + var processErr *ProcessError + if assert.ErrorAs(t, err, &processErr) { + assert.Equal(t, "1", strings.TrimSpace(processErr.Stderr)) + assert.Equal(t, "2", strings.TrimSpace(processErr.Stdout)) + } + assert.Equal(t, "2", strings.TrimSpace(res)) + assert.Equal(t, "1\n2\n", strings.ReplaceAll(buf.String(), "\r", "")) +} + +func TestBackgroundNoStdin(t *testing.T) { + ctx := context.Background() + res, err := Background(ctx, []string{"cat"}) + assert.NoError(t, err) + assert.Equal(t, "", res) +} + +func TestBackgroundFails(t *testing.T) { + ctx := context.Background() + _, err := Background(ctx, []string{"ls", "/dev/null/x"}) + assert.NotNil(t, err) +} + +func TestBackgroundFailsOnOption(t *testing.T) { + ctx := context.Background() + _, err := Background(ctx, []string{"ls", "/dev/null/x"}, func(_ context.Context, c *exec.Cmd) error { + return fmt.Errorf("nope") + }) + assert.EqualError(t, err, "nope") +} diff --git a/libs/process/forwarded.go b/libs/process/forwarded.go new file mode 100644 index 00000000..df3c2dbd --- /dev/null +++ b/libs/process/forwarded.go @@ -0,0 +1,43 @@ +package process + +import ( + "context" + "fmt" + "io" + "os/exec" + "strings" + + "github.com/databricks/cli/libs/env" + "github.com/databricks/cli/libs/log" +) + +func Forwarded(ctx context.Context, args []string, src io.Reader, outWriter, errWriter io.Writer, opts ...execOption) error { + commandStr := strings.Join(args, " ") + log.Debugf(ctx, "starting: %s", commandStr) + cmd := exec.CommandContext(ctx, args[0], args[1:]...) 
+ + // empirical tests showed buffered copies being more responsive (NOTE(review): the streams below are assigned directly, with no explicit buffering — confirm this comment is still accurate) + cmd.Stdout = outWriter + cmd.Stderr = errWriter + cmd.Stdin = src + // we pull the env through libs/env such that we can run + // parallel tests with anything using libs/process. + for k, v := range env.All(ctx) { + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v)) + } + + // apply common options + for _, o := range opts { + err := o(ctx, cmd) + if err != nil { + return err + } + } + + err := cmd.Start() + if err != nil { + return err + } + + return cmd.Wait() +} diff --git a/libs/process/forwarded_test.go b/libs/process/forwarded_test.go new file mode 100644 index 00000000..ddb79818 --- /dev/null +++ b/libs/process/forwarded_test.go @@ -0,0 +1,43 @@ +package process + +import ( + "bytes" + "context" + "os/exec" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestForwarded(t *testing.T) { + ctx := context.Background() + var buf bytes.Buffer + err := Forwarded(ctx, []string{ + "python3", "-c", "print(input('input: '))", + }, strings.NewReader("abc\n"), &buf, &buf) + assert.NoError(t, err) + + assert.Equal(t, "input: abc", strings.TrimSpace(buf.String())) +} + +func TestForwardedFails(t *testing.T) { + ctx := context.Background() + var buf bytes.Buffer + err := Forwarded(ctx, []string{ + "_non_existent_", + }, strings.NewReader("abc\n"), &buf, &buf) + assert.NotNil(t, err) +} + +func TestForwardedFailsOnStdinPipe(t *testing.T) { + ctx := context.Background() + var buf bytes.Buffer + err := Forwarded(ctx, []string{ + "_non_existent_", + }, strings.NewReader("abc\n"), &buf, &buf, func(_ context.Context, c *exec.Cmd) error { + c.Stdin = strings.NewReader("x") + return nil + }) + assert.NotNil(t, err) +} diff --git a/libs/process/opts.go b/libs/process/opts.go new file mode 100644 index 00000000..e201c666 --- /dev/null +++ b/libs/process/opts.go @@ -0,0 +1,57 @@ +package process + +import ( + "bytes" + "context" + "fmt" + "io" + "os/exec" +) + +type execOption 
func(context.Context, *exec.Cmd) error + +func WithEnv(key, value string) execOption { + return func(ctx context.Context, c *exec.Cmd) error { + v := fmt.Sprintf("%s=%s", key, value) + c.Env = append(c.Env, v) + return nil + } +} + +func WithEnvs(envs map[string]string) execOption { + return func(ctx context.Context, c *exec.Cmd) error { + for k, v := range envs { + err := WithEnv(k, v)(ctx, c) + if err != nil { + return err + } + } + return nil + } +} + +func WithDir(dir string) execOption { + return func(_ context.Context, c *exec.Cmd) error { + c.Dir = dir + return nil + } +} + +func WithStdoutPipe(dst *io.ReadCloser) execOption { + return func(_ context.Context, c *exec.Cmd) error { + outPipe, err := c.StdoutPipe() + if err != nil { + return err + } + *dst = outPipe + return nil + } +} + +func WithCombinedOutput(buf *bytes.Buffer) execOption { + return func(_ context.Context, c *exec.Cmd) error { + c.Stdout = io.MultiWriter(buf, c.Stdout) + c.Stderr = io.MultiWriter(buf, c.Stderr) + return nil + } +} diff --git a/libs/process/opts_test.go b/libs/process/opts_test.go new file mode 100644 index 00000000..3a819fbb --- /dev/null +++ b/libs/process/opts_test.go @@ -0,0 +1,47 @@ +package process + +import ( + "context" + "os/exec" + "runtime" + "sort" + "testing" + + "github.com/databricks/cli/internal/testutil" + "github.com/databricks/cli/libs/env" + "github.com/stretchr/testify/assert" +) + +func TestWithEnvs(t *testing.T) { + if runtime.GOOS == "windows" { + // Skipping test on windows for now because of the following error: + // /bin/sh -c echo $FOO $BAR: exec: "/bin/sh": file does not exist + t.SkipNow() + } + ctx := context.Background() + ctx2 := env.Set(ctx, "FOO", "foo") + res, err := Background(ctx2, []string{"/bin/sh", "-c", "echo $FOO $BAR"}, WithEnvs(map[string]string{ + "BAR": "delirium", + })) + assert.NoError(t, err) + assert.Equal(t, "foo delirium\n", res) +} + +func TestWorksWithLibsEnv(t *testing.T) { + testutil.CleanupEnvironment(t) + ctx := 
context.Background() + + cmd := &exec.Cmd{} + err := WithEnvs(map[string]string{ + "CCC": "DDD", + "EEE": "FFF", + })(ctx, cmd) + assert.NoError(t, err) + + vars := cmd.Environ() + sort.Strings(vars) + + assert.True(t, len(vars) >= 2) + assert.Equal(t, "CCC=DDD", vars[0]) + assert.Equal(t, "EEE=FFF", vars[1]) +} diff --git a/python/runner.go b/python/runner.go index bdf386a0..ebf24717 100644 --- a/python/runner.go +++ b/python/runner.go @@ -8,6 +8,8 @@ import ( "os/exec" "runtime" "strings" + + "github.com/databricks/cli/libs/process" ) func PyInline(ctx context.Context, inlinePy string) (string, error) { @@ -88,8 +90,8 @@ func DetectExecutable(ctx context.Context) (string, error) { func execAndPassErr(ctx context.Context, name string, args ...string) ([]byte, error) { // TODO: move out to a separate package, once we have Maven integration - out, err := exec.CommandContext(ctx, name, args...).Output() - return out, nicerErr(err) + out, err := process.Background(ctx, append([]string{name}, args...)) + return []byte(out), nicerErr(err) } func getFirstMatch(out string) string { diff --git a/python/runner_test.go b/python/runner_test.go index 3968e27a..fc8f2508 100644 --- a/python/runner_test.go +++ b/python/runner_test.go @@ -20,7 +20,7 @@ func TestExecAndPassError(t *testing.T) { } _, err := execAndPassErr(context.Background(), "which", "__non_existing__") - assert.EqualError(t, err, "exit status 1") + assert.EqualError(t, err, "which __non_existing__: exit status 1") } func TestDetectPython(t *testing.T) { @@ -77,7 +77,7 @@ func testTempdir(t *testing.T, dir *string) func() { func TestPyError(t *testing.T) { _, err := Py(context.Background(), "__non_existing__.py") - assert.Contains(t, err.Error(), "can't open file") + assert.Contains(t, err.Error(), "exit status 2") } func TestPyInline(t *testing.T) { @@ -90,5 +90,5 @@ func TestPyInlineStderr(t *testing.T) { DetectExecutable(context.Background()) inline := "import sys; sys.stderr.write('___msg___'); sys.exit(1)" _, 
err := PyInline(context.Background(), inline) - assert.EqualError(t, err, "___msg___") + assert.ErrorContains(t, err, "___msg___") }