diff --git a/bundle/bundle.go b/bundle/bundle.go index 8175ce28..4fc60539 100644 --- a/bundle/bundle.go +++ b/bundle/bundle.go @@ -14,6 +14,7 @@ import ( "sync" "github.com/databricks/cli/bundle/config" + "github.com/databricks/cli/bundle/env" "github.com/databricks/cli/folders" "github.com/databricks/cli/libs/git" "github.com/databricks/cli/libs/locker" @@ -51,8 +52,6 @@ type Bundle struct { AutoApprove bool } -const ExtraIncludePathsKey string = "DATABRICKS_BUNDLE_INCLUDES" - func Load(ctx context.Context, path string) (*Bundle, error) { bundle := &Bundle{} stat, err := os.Stat(path) @@ -61,9 +60,9 @@ func Load(ctx context.Context, path string) (*Bundle, error) { } configFile, err := config.FileNames.FindInPath(path) if err != nil { - _, hasIncludePathEnv := os.LookupEnv(ExtraIncludePathsKey) - _, hasBundleRootEnv := os.LookupEnv(envBundleRoot) - if hasIncludePathEnv && hasBundleRootEnv && stat.IsDir() { + _, hasRootEnv := env.Root(ctx) + _, hasIncludesEnv := env.Includes(ctx) + if hasRootEnv && hasIncludesEnv && stat.IsDir() { log.Debugf(ctx, "No bundle configuration; using bundle root: %s", path) bundle.Config = config.Root{ Path: path, @@ -86,7 +85,7 @@ func Load(ctx context.Context, path string) (*Bundle, error) { // MustLoad returns a bundle configuration. // It returns an error if a bundle was not found or could not be loaded. func MustLoad(ctx context.Context) (*Bundle, error) { - root, err := mustGetRoot() + root, err := mustGetRoot(ctx) if err != nil { return nil, err } @@ -98,7 +97,7 @@ func MustLoad(ctx context.Context) (*Bundle, error) { // It returns an error if a bundle was found but could not be loaded. // It returns a `nil` bundle if a bundle was not found. func TryLoad(ctx context.Context) (*Bundle, error) { - root, err := tryGetRoot() + root, err := tryGetRoot(ctx) if err != nil { return nil, err } @@ -124,13 +123,12 @@ func (b *Bundle) WorkspaceClient() *databricks.WorkspaceClient { // CacheDir returns directory to use for temporary files for this bundle. // Scoped to the bundle's target. -func (b *Bundle) CacheDir(paths ...string) (string, error) { +func (b *Bundle) CacheDir(ctx context.Context, paths ...string) (string, error) { if b.Config.Bundle.Target == "" { panic("target not set") } - cacheDirName, exists := os.LookupEnv("DATABRICKS_BUNDLE_TMP") - + cacheDirName, exists := env.TempDir(ctx) if !exists || cacheDirName == "" { cacheDirName = filepath.Join( // Anchor at bundle root directory. @@ -163,8 +161,8 @@ func (b *Bundle) CacheDir(paths ...string) (string, error) { // This directory is used to store and automaticaly sync internal bundle files, such as, f.e // notebook trampoline files for Python wheel and etc. -func (b *Bundle) InternalDir() (string, error) { - cacheDir, err := b.CacheDir() +func (b *Bundle) InternalDir(ctx context.Context) (string, error) { + cacheDir, err := b.CacheDir(ctx) if err != nil { return "", err } @@ -181,8 +179,8 @@ func (b *Bundle) InternalDir() (string, error) { // GetSyncIncludePatterns returns a list of user defined includes // And also adds InternalDir folder to include list for sync command // so this folder is always synced -func (b *Bundle) GetSyncIncludePatterns() ([]string, error) { - internalDir, err := b.InternalDir() +func (b *Bundle) GetSyncIncludePatterns(ctx context.Context) ([]string, error) { + internalDir, err := b.InternalDir(ctx) if err != nil { return nil, err } diff --git a/bundle/bundle_test.go b/bundle/bundle_test.go index 4a3e7f2c..43477efd 100644 --- a/bundle/bundle_test.go +++ b/bundle/bundle_test.go @@ -6,6 +6,7 @@ import ( "path/filepath" "testing" + "github.com/databricks/cli/bundle/env" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -23,12 +24,13 @@ func TestLoadExists(t *testing.T) { } func TestBundleCacheDir(t *testing.T) { + ctx := context.Background() projectDir := t.TempDir() f1, err := os.Create(filepath.Join(projectDir, "databricks.yml")) require.NoError(t, err) f1.Close() - bundle, err := Load(context.Background(), projectDir) + bundle, err := Load(ctx, projectDir) require.NoError(t, err) // Artificially set target. @@ -38,7 +40,7 @@ func TestBundleCacheDir(t *testing.T) { // unset env variable in case it's set t.Setenv("DATABRICKS_BUNDLE_TMP", "") - cacheDir, err := bundle.CacheDir() + cacheDir, err := bundle.CacheDir(ctx) // format is /.databricks/bundle/ assert.NoError(t, err) @@ -46,13 +48,14 @@ func TestBundleCacheDir(t *testing.T) { } func TestBundleCacheDirOverride(t *testing.T) { + ctx := context.Background() projectDir := t.TempDir() bundleTmpDir := t.TempDir() f1, err := os.Create(filepath.Join(projectDir, "databricks.yml")) require.NoError(t, err) f1.Close() - bundle, err := Load(context.Background(), projectDir) + bundle, err := Load(ctx, projectDir) require.NoError(t, err) // Artificially set target. @@ -62,7 +65,7 @@ func TestBundleCacheDirOverride(t *testing.T) { // now we expect to use 'bundleTmpDir' instead of CWD/.databricks/bundle t.Setenv("DATABRICKS_BUNDLE_TMP", bundleTmpDir) - cacheDir, err := bundle.CacheDir() + cacheDir, err := bundle.CacheDir(ctx) // format is / assert.NoError(t, err) @@ -70,14 +73,14 @@ func TestBundleCacheDirOverride(t *testing.T) { } func TestBundleMustLoadSuccess(t *testing.T) { - t.Setenv(envBundleRoot, "./tests/basic") + t.Setenv(env.RootVariable, "./tests/basic") b, err := MustLoad(context.Background()) require.NoError(t, err) assert.Equal(t, "tests/basic", filepath.ToSlash(b.Config.Path)) } func TestBundleMustLoadFailureWithEnv(t *testing.T) { - t.Setenv(envBundleRoot, "./tests/doesntexist") + t.Setenv(env.RootVariable, "./tests/doesntexist") _, err := MustLoad(context.Background()) require.Error(t, err, "not a directory") } @@ -89,14 +92,14 @@ func TestBundleMustLoadFailureIfNotFound(t *testing.T) { } func TestBundleTryLoadSuccess(t *testing.T) { - t.Setenv(envBundleRoot, "./tests/basic") + t.Setenv(env.RootVariable, "./tests/basic") b, err := TryLoad(context.Background()) require.NoError(t, err) assert.Equal(t, "tests/basic", filepath.ToSlash(b.Config.Path)) } func TestBundleTryLoadFailureWithEnv(t *testing.T) { - t.Setenv(envBundleRoot, "./tests/doesntexist") + t.Setenv(env.RootVariable, "./tests/doesntexist") _, err := TryLoad(context.Background()) require.Error(t, err, "not a directory") } diff --git a/bundle/config/mutator/override_compute.go b/bundle/config/mutator/override_compute.go index ee2e2a82..21d95013 100644 --- a/bundle/config/mutator/override_compute.go +++ b/bundle/config/mutator/override_compute.go @@ -3,11 +3,11 @@ package mutator import ( "context" "fmt" - "os" "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config" "github.com/databricks/cli/bundle/config/resources" + "github.com/databricks/cli/libs/env" ) type overrideCompute struct{} @@ -39,8 +39,8 @@ func (m *overrideCompute) Apply(ctx context.Context, b *bundle.Bundle) error { } return nil } - if os.Getenv("DATABRICKS_CLUSTER_ID") != "" { - b.Config.Bundle.ComputeID = os.Getenv("DATABRICKS_CLUSTER_ID") + if v := env.Get(ctx, "DATABRICKS_CLUSTER_ID"); v != "" { + b.Config.Bundle.ComputeID = v } if b.Config.Bundle.ComputeID == "" { diff --git a/bundle/config/mutator/process_root_includes.go b/bundle/config/mutator/process_root_includes.go index 98992872..5a5ab1b1 100644 --- a/bundle/config/mutator/process_root_includes.go +++ b/bundle/config/mutator/process_root_includes.go @@ -10,11 +10,12 @@ import ( "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config" + "github.com/databricks/cli/bundle/env" ) // Get extra include paths from environment variable -func GetExtraIncludePaths() []string { - value, exists := os.LookupEnv(bundle.ExtraIncludePathsKey) +func getExtraIncludePaths(ctx context.Context) []string { + value, exists := env.Includes(ctx) if !exists { return nil } @@ -48,7 +49,7 @@ func (m *processRootIncludes) Apply(ctx context.Context, b *bundle.Bundle) error var files []string // Converts extra include paths from environment variable to relative paths - for _, extraIncludePath := range GetExtraIncludePaths() { + for _, extraIncludePath := range getExtraIncludePaths(ctx) { if filepath.IsAbs(extraIncludePath) { rel, err := filepath.Rel(b.Config.Path, extraIncludePath) if err != nil { diff --git a/bundle/config/mutator/process_root_includes_test.go b/bundle/config/mutator/process_root_includes_test.go index 1ce094bc..aec9b32d 100644 --- a/bundle/config/mutator/process_root_includes_test.go +++ b/bundle/config/mutator/process_root_includes_test.go @@ -2,16 +2,17 @@ package mutator_test import ( "context" - "fmt" "os" "path" "path/filepath" "runtime" + "strings" "testing" "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config" "github.com/databricks/cli/bundle/config/mutator" + "github.com/databricks/cli/bundle/env" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -129,10 +130,7 @@ func TestProcessRootIncludesExtrasFromEnvVar(t *testing.T) { rootPath := t.TempDir() testYamlName := "extra_include_path.yml" touch(t, rootPath, testYamlName) - os.Setenv(bundle.ExtraIncludePathsKey, path.Join(rootPath, testYamlName)) - t.Cleanup(func() { - os.Unsetenv(bundle.ExtraIncludePathsKey) - }) + t.Setenv(env.IncludesVariable, path.Join(rootPath, testYamlName)) bundle := &bundle.Bundle{ Config: config.Root{ @@ -149,7 +147,13 @@ func TestProcessRootIncludesDedupExtrasFromEnvVar(t *testing.T) { rootPath := t.TempDir() testYamlName := "extra_include_path.yml" touch(t, rootPath, testYamlName) - t.Setenv(bundle.ExtraIncludePathsKey, fmt.Sprintf("%s%s%s", path.Join(rootPath, testYamlName), string(os.PathListSeparator), path.Join(rootPath, testYamlName))) + t.Setenv(env.IncludesVariable, strings.Join( + []string{ + path.Join(rootPath, testYamlName), + path.Join(rootPath, testYamlName), + }, + string(os.PathListSeparator), + )) bundle := &bundle.Bundle{ Config: config.Root{ diff --git a/bundle/config/mutator/set_variables.go b/bundle/config/mutator/set_variables.go index 427b6dce..4bf8ff82 100644 --- a/bundle/config/mutator/set_variables.go +++ b/bundle/config/mutator/set_variables.go @@ -3,10 +3,10 @@ package mutator import ( "context" "fmt" - "os" "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config/variable" + "github.com/databricks/cli/libs/env" ) const bundleVarPrefix = "BUNDLE_VAR_" @@ -21,7 +21,7 @@ func (m *setVariables) Name() string { return "SetVariables" } -func setVariable(v *variable.Variable, name string) error { +func setVariable(ctx context.Context, v *variable.Variable, name string) error { // case: variable already has value initialized, so skip if v.HasValue() { return nil @@ -29,7 +29,7 @@ func setVariable(v *variable.Variable, name string) error { // case: read and set variable value from process environment envVarName := bundleVarPrefix + name - if val, ok := os.LookupEnv(envVarName); ok { + if val, ok := env.Lookup(ctx, envVarName); ok { err := v.Set(val) if err != nil { return fmt.Errorf(`failed to assign value "%s" to variable %s from environment variable %s with error: %w`, val, name, envVarName, err) @@ -54,7 +54,7 @@ func setVariable(v *variable.Variable, name string) error { func (m *setVariables) Apply(ctx context.Context, b *bundle.Bundle) error { for name, variable := range b.Config.Variables { - err := setVariable(variable, name) + err := setVariable(ctx, variable, name) if err != nil { return err } diff --git a/bundle/config/mutator/set_variables_test.go b/bundle/config/mutator/set_variables_test.go index 91948aa4..323f1e86 100644 --- a/bundle/config/mutator/set_variables_test.go +++ b/bundle/config/mutator/set_variables_test.go @@ -21,7 +21,7 @@ func TestSetVariableFromProcessEnvVar(t *testing.T) { // set value for variable as an environment variable t.Setenv("BUNDLE_VAR_foo", "process-env") - err := setVariable(&variable, "foo") + err := setVariable(context.Background(), &variable, "foo") require.NoError(t, err) assert.Equal(t, *variable.Value, "process-env") } @@ -33,7 +33,7 @@ func TestSetVariableUsingDefaultValue(t *testing.T) { Default: &defaultVal, } - err := setVariable(&variable, "foo") + err := setVariable(context.Background(), &variable, "foo") require.NoError(t, err) assert.Equal(t, *variable.Value, "default") } @@ -49,7 +49,7 @@ func TestSetVariableWhenAlreadyAValueIsAssigned(t *testing.T) { // since a value is already assigned to the variable, it would not be overridden // by the default value - err := setVariable(&variable, "foo") + err := setVariable(context.Background(), &variable, "foo") require.NoError(t, err) assert.Equal(t, *variable.Value, "assigned-value") } @@ -68,7 +68,7 @@ func TestSetVariableEnvVarValueDoesNotOverridePresetValue(t *testing.T) { // since a value is already assigned to the variable, it would not be overridden // by the value from environment - err := setVariable(&variable, "foo") + err := setVariable(context.Background(), &variable, "foo") require.NoError(t, err) assert.Equal(t, *variable.Value, "assigned-value") } @@ -79,7 +79,7 @@ func TestSetVariablesErrorsIfAValueCouldNotBeResolved(t *testing.T) { } // fails because we could not resolve a value for the variable - err := setVariable(&variable, "foo") + err := setVariable(context.Background(), &variable, "foo") assert.ErrorContains(t, err, "no value assigned to required variable foo. Assignment can be done through the \"--var\" flag or by setting the BUNDLE_VAR_foo environment variable") } diff --git a/bundle/config/mutator/trampoline.go b/bundle/config/mutator/trampoline.go index 7c06c7fa..52d62c1b 100644 --- a/bundle/config/mutator/trampoline.go +++ b/bundle/config/mutator/trampoline.go @@ -43,7 +43,7 @@ func (m *trampoline) Name() string { func (m *trampoline) Apply(ctx context.Context, b *bundle.Bundle) error { tasks := m.functions.GetTasks(b) for _, task := range tasks { - err := m.generateNotebookWrapper(b, task) + err := m.generateNotebookWrapper(ctx, b, task) if err != nil { return err } @@ -51,8 +51,8 @@ func (m *trampoline) Apply(ctx context.Context, b *bundle.Bundle) error { return nil } -func (m *trampoline) generateNotebookWrapper(b *bundle.Bundle, task TaskWithJobKey) error { - internalDir, err := b.InternalDir() +func (m *trampoline) generateNotebookWrapper(ctx context.Context, b *bundle.Bundle, task TaskWithJobKey) error { + internalDir, err := b.InternalDir(ctx) if err != nil { return err } diff --git a/bundle/config/mutator/trampoline_test.go b/bundle/config/mutator/trampoline_test.go index aec58618..a3e06b30 100644 --- a/bundle/config/mutator/trampoline_test.go +++ b/bundle/config/mutator/trampoline_test.go @@ -83,7 +83,7 @@ func TestGenerateTrampoline(t *testing.T) { err := bundle.Apply(ctx, b, trampoline) require.NoError(t, err) - dir, err := b.InternalDir() + dir, err := b.InternalDir(ctx) require.NoError(t, err) filename := filepath.Join(dir, "notebook_test_to_trampoline.py") diff --git a/bundle/deploy/files/sync.go b/bundle/deploy/files/sync.go index 2dccd20a..ff3d78d0 100644 --- a/bundle/deploy/files/sync.go +++ b/bundle/deploy/files/sync.go @@ -9,12 +9,12 @@ import ( ) func getSync(ctx context.Context, b *bundle.Bundle) (*sync.Sync, error) { - cacheDir, err := b.CacheDir() + cacheDir, err := b.CacheDir(ctx) if err != nil { return nil, fmt.Errorf("cannot get bundle cache directory: %w", err) } - includes, err := b.GetSyncIncludePatterns() + includes, err := b.GetSyncIncludePatterns(ctx) if err != nil { return nil, fmt.Errorf("cannot get list of sync includes: %w", err) } diff --git a/bundle/deploy/terraform/dir.go b/bundle/deploy/terraform/dir.go index 9f83b8da..b7b086ce 100644 --- a/bundle/deploy/terraform/dir.go +++ b/bundle/deploy/terraform/dir.go @@ -1,11 +1,13 @@ package terraform import ( + "context" + "github.com/databricks/cli/bundle" ) // Dir returns the Terraform working directory for a given bundle. // The working directory is emphemeral and nested under the bundle's cache directory. -func Dir(b *bundle.Bundle) (string, error) { - return b.CacheDir("terraform") +func Dir(ctx context.Context, b *bundle.Bundle) (string, error) { + return b.CacheDir(ctx, "terraform") } diff --git a/bundle/deploy/terraform/init.go b/bundle/deploy/terraform/init.go index 60f0a6c4..aa1dff74 100644 --- a/bundle/deploy/terraform/init.go +++ b/bundle/deploy/terraform/init.go @@ -12,6 +12,7 @@ import ( "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config" + "github.com/databricks/cli/libs/env" "github.com/databricks/cli/libs/log" "github.com/hashicorp/go-version" "github.com/hashicorp/hc-install/product" @@ -38,7 +39,7 @@ func (m *initialize) findExecPath(ctx context.Context, b *bundle.Bundle, tf *con return tf.ExecPath, nil } - binDir, err := b.CacheDir("bin") + binDir, err := b.CacheDir(context.Background(), "bin") if err != nil { return "", err } @@ -73,25 +74,25 @@ func (m *initialize) findExecPath(ctx context.Context, b *bundle.Bundle, tf *con } // This function inherits some environment variables for Terraform CLI. -func inheritEnvVars(env map[string]string) error { +func inheritEnvVars(ctx context.Context, environ map[string]string) error { // Include $HOME in set of environment variables to pass along. - home, ok := os.LookupEnv("HOME") + home, ok := env.Lookup(ctx, "HOME") if ok { - env["HOME"] = home + environ["HOME"] = home } // Include $PATH in set of environment variables to pass along. // This is necessary to ensure that our Terraform provider can use the // same auxiliary programs (e.g. `az`, or `gcloud`) as the CLI. - path, ok := os.LookupEnv("PATH") + path, ok := env.Lookup(ctx, "PATH") if ok { - env["PATH"] = path + environ["PATH"] = path } // Include $TF_CLI_CONFIG_FILE to override terraform provider in development. - configFile, ok := os.LookupEnv("TF_CLI_CONFIG_FILE") + configFile, ok := env.Lookup(ctx, "TF_CLI_CONFIG_FILE") if ok { - env["TF_CLI_CONFIG_FILE"] = configFile + environ["TF_CLI_CONFIG_FILE"] = configFile } return nil @@ -105,40 +106,40 @@ func inheritEnvVars(env map[string]string) error { // the CLI and its dependencies do not have access to. // // see: os.TempDir for more context -func setTempDirEnvVars(env map[string]string, b *bundle.Bundle) error { +func setTempDirEnvVars(ctx context.Context, environ map[string]string, b *bundle.Bundle) error { switch runtime.GOOS { case "windows": - if v, ok := os.LookupEnv("TMP"); ok { - env["TMP"] = v - } else if v, ok := os.LookupEnv("TEMP"); ok { - env["TEMP"] = v - } else if v, ok := os.LookupEnv("USERPROFILE"); ok { - env["USERPROFILE"] = v + if v, ok := env.Lookup(ctx, "TMP"); ok { + environ["TMP"] = v + } else if v, ok := env.Lookup(ctx, "TEMP"); ok { + environ["TEMP"] = v + } else if v, ok := env.Lookup(ctx, "USERPROFILE"); ok { + environ["USERPROFILE"] = v } else { - tmpDir, err := b.CacheDir("tmp") + tmpDir, err := b.CacheDir(ctx, "tmp") if err != nil { return err } - env["TMP"] = tmpDir + environ["TMP"] = tmpDir } default: // If TMPDIR is not set, we let the process fall back to its default value. - if v, ok := os.LookupEnv("TMPDIR"); ok { - env["TMPDIR"] = v + if v, ok := env.Lookup(ctx, "TMPDIR"); ok { + environ["TMPDIR"] = v } } return nil } // This function passes through all proxy related environment variables. -func setProxyEnvVars(env map[string]string, b *bundle.Bundle) error { +func setProxyEnvVars(ctx context.Context, environ map[string]string, b *bundle.Bundle) error { for _, v := range []string{"http_proxy", "https_proxy", "no_proxy"} { // The case (upper or lower) is notoriously inconsistent for tools on Unix systems. // We therefore try to read both the upper and lower case versions of the variable. for _, v := range []string{strings.ToUpper(v), strings.ToLower(v)} { - if val, ok := os.LookupEnv(v); ok { + if val, ok := env.Lookup(ctx, v); ok { // Only set uppercase version of the variable. - env[strings.ToUpper(v)] = val + environ[strings.ToUpper(v)] = val } } } @@ -157,7 +158,7 @@ func (m *initialize) Apply(ctx context.Context, b *bundle.Bundle) error { return err } - workingDir, err := Dir(b) + workingDir, err := Dir(ctx, b) if err != nil { return err } @@ -167,31 +168,31 @@ func (m *initialize) Apply(ctx context.Context, b *bundle.Bundle) error { return err } - env, err := b.AuthEnv() + environ, err := b.AuthEnv() if err != nil { return err } - err = inheritEnvVars(env) + err = inheritEnvVars(ctx, environ) if err != nil { return err } // Set the temporary directory environment variables - err = setTempDirEnvVars(env, b) + err = setTempDirEnvVars(ctx, environ, b) if err != nil { return err } // Set the proxy related environment variables - err = setProxyEnvVars(env, b) + err = setProxyEnvVars(ctx, environ, b) if err != nil { return err } // Configure environment variables for auth for Terraform to use. - log.Debugf(ctx, "Environment variables for Terraform: %s", strings.Join(maps.Keys(env), ", ")) - err = tf.SetEnv(env) + log.Debugf(ctx, "Environment variables for Terraform: %s", strings.Join(maps.Keys(environ), ", ")) + err = tf.SetEnv(environ) if err != nil { return err } diff --git a/bundle/deploy/terraform/init_test.go b/bundle/deploy/terraform/init_test.go index b9459387..001e7a22 100644 --- a/bundle/deploy/terraform/init_test.go +++ b/bundle/deploy/terraform/init_test.go @@ -68,7 +68,7 @@ func TestSetTempDirEnvVarsForUnixWithTmpDirSet(t *testing.T) { // compute env env := make(map[string]string, 0) - err := setTempDirEnvVars(env, b) + err := setTempDirEnvVars(context.Background(), env, b) require.NoError(t, err) // Assert that we pass through TMPDIR. @@ -96,7 +96,7 @@ func TestSetTempDirEnvVarsForUnixWithTmpDirNotSet(t *testing.T) { // compute env env := make(map[string]string, 0) - err := setTempDirEnvVars(env, b) + err := setTempDirEnvVars(context.Background(), env, b) require.NoError(t, err) // Assert that we don't pass through TMPDIR. @@ -124,7 +124,7 @@ func TestSetTempDirEnvVarsForWindowWithAllTmpDirEnvVarsSet(t *testing.T) { // compute env env := make(map[string]string, 0) - err := setTempDirEnvVars(env, b) + err := setTempDirEnvVars(context.Background(), env, b) require.NoError(t, err) // assert that we pass through the highest priority env var value @@ -154,7 +154,7 @@ func TestSetTempDirEnvVarsForWindowWithUserProfileAndTempSet(t *testing.T) { // compute env env := make(map[string]string, 0) - err := setTempDirEnvVars(env, b) + err := setTempDirEnvVars(context.Background(), env, b) require.NoError(t, err) // assert that we pass through the highest priority env var value @@ -184,7 +184,7 @@ func TestSetTempDirEnvVarsForWindowWithUserProfileSet(t *testing.T) { // compute env env := make(map[string]string, 0) - err := setTempDirEnvVars(env, b) + err := setTempDirEnvVars(context.Background(), env, b) require.NoError(t, err) // assert that we pass through the user profile @@ -214,11 +214,11 @@ func TestSetTempDirEnvVarsForWindowsWithoutAnyTempDirEnvVarsSet(t *testing.T) { // compute env env := make(map[string]string, 0) - err := setTempDirEnvVars(env, b) + err := setTempDirEnvVars(context.Background(), env, b) require.NoError(t, err) // assert TMP is set to b.CacheDir("tmp") - tmpDir, err := b.CacheDir("tmp") + tmpDir, err := b.CacheDir(context.Background(), "tmp") require.NoError(t, err) assert.Equal(t, map[string]string{ "TMP": tmpDir, @@ -248,7 +248,7 @@ func TestSetProxyEnvVars(t *testing.T) { // No proxy env vars set. clearEnv() env := make(map[string]string, 0) - err := setProxyEnvVars(env, b) + err := setProxyEnvVars(context.Background(), env, b) require.NoError(t, err) assert.Len(t, env, 0) @@ -258,7 +258,7 @@ func TestSetProxyEnvVars(t *testing.T) { t.Setenv("https_proxy", "foo") t.Setenv("no_proxy", "foo") env = make(map[string]string, 0) - err = setProxyEnvVars(env, b) + err = setProxyEnvVars(context.Background(), env, b) require.NoError(t, err) assert.ElementsMatch(t, []string{"HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY"}, maps.Keys(env)) @@ -268,7 +268,7 @@ func TestSetProxyEnvVars(t *testing.T) { t.Setenv("HTTPS_PROXY", "foo") t.Setenv("NO_PROXY", "foo") env = make(map[string]string, 0) - err = setProxyEnvVars(env, b) + err = setProxyEnvVars(context.Background(), env, b) require.NoError(t, err) assert.ElementsMatch(t, []string{"HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY"}, maps.Keys(env)) } @@ -280,7 +280,7 @@ func TestInheritEnvVars(t *testing.T) { t.Setenv("PATH", "/foo:/bar") t.Setenv("TF_CLI_CONFIG_FILE", "/tmp/config.tfrc") - err := inheritEnvVars(env) + err := inheritEnvVars(context.Background(), env) require.NoError(t, err) diff --git a/bundle/deploy/terraform/plan.go b/bundle/deploy/terraform/plan.go index a725b4aa..ff841148 100644 --- a/bundle/deploy/terraform/plan.go +++ b/bundle/deploy/terraform/plan.go @@ -40,7 +40,7 @@ func (p *plan) Apply(ctx context.Context, b *bundle.Bundle) error { } // Persist computed plan - tfDir, err := Dir(b) + tfDir, err := Dir(ctx, b) if err != nil { return err } diff --git a/bundle/deploy/terraform/state_pull.go b/bundle/deploy/terraform/state_pull.go index e5a42d89..6dd12ccf 100644 --- a/bundle/deploy/terraform/state_pull.go +++ b/bundle/deploy/terraform/state_pull.go @@ -25,7 +25,7 @@ func (l *statePull) Apply(ctx context.Context, b *bundle.Bundle) error { return err } - dir, err := Dir(b) + dir, err := Dir(ctx, b) if err != nil { return err } diff --git a/bundle/deploy/terraform/state_push.go b/bundle/deploy/terraform/state_push.go index 0cd69e52..ae1d8b8b 100644 --- a/bundle/deploy/terraform/state_push.go +++ b/bundle/deploy/terraform/state_push.go @@ -22,7 +22,7 @@ func (l *statePush) Apply(ctx context.Context, b *bundle.Bundle) error { return err } - dir, err := Dir(b) + dir, err := Dir(ctx, b) if err != nil { return err } diff --git a/bundle/deploy/terraform/write.go b/bundle/deploy/terraform/write.go index 0bf9ab24..eca79ad2 100644 --- a/bundle/deploy/terraform/write.go +++ b/bundle/deploy/terraform/write.go @@ -16,7 +16,7 @@ func (w *write) Name() string { } func (w *write) Apply(ctx context.Context, b *bundle.Bundle) error { - dir, err := Dir(b) + dir, err := Dir(ctx, b) if err != nil { return err } diff --git a/bundle/env/env.go b/bundle/env/env.go new file mode 100644 index 00000000..ed2a13c7 --- /dev/null +++ b/bundle/env/env.go @@ -0,0 +1,18 @@ +package env + +import ( + "context" + + envlib "github.com/databricks/cli/libs/env" +) + +// Return the value of the first environment variable that is set. +func get(ctx context.Context, variables []string) (string, bool) { + for _, v := range variables { + value, ok := envlib.Lookup(ctx, v) + if ok { + return value, true + } + } + return "", false +} diff --git a/bundle/env/env_test.go b/bundle/env/env_test.go new file mode 100644 index 00000000..d900242e --- /dev/null +++ b/bundle/env/env_test.go @@ -0,0 +1,44 @@ +package env + +import ( + "context" + "testing" + + "github.com/databricks/cli/internal/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetWithRealEnvSingleVariable(t *testing.T) { + testutil.CleanupEnvironment(t) + t.Setenv("v1", "foo") + + v, ok := get(context.Background(), []string{"v1"}) + require.True(t, ok) + assert.Equal(t, "foo", v) + + // Not set. + v, ok = get(context.Background(), []string{"v2"}) + require.False(t, ok) + assert.Equal(t, "", v) +} + +func TestGetWithRealEnvMultipleVariables(t *testing.T) { + testutil.CleanupEnvironment(t) + t.Setenv("v1", "foo") + + for _, vars := range [][]string{ + {"v1", "v2", "v3"}, + {"v2", "v3", "v1"}, + {"v3", "v1", "v2"}, + } { + v, ok := get(context.Background(), vars) + require.True(t, ok) + assert.Equal(t, "foo", v) + } + + // Not set. + v, ok := get(context.Background(), []string{"v2", "v3", "v4"}) + require.False(t, ok) + assert.Equal(t, "", v) +} diff --git a/bundle/env/includes.go b/bundle/env/includes.go new file mode 100644 index 00000000..4ade0187 --- /dev/null +++ b/bundle/env/includes.go @@ -0,0 +1,14 @@ +package env + +import "context" + +// IncludesVariable names the environment variable that holds additional configuration paths to include +// during bundle configuration loading. Also see `bundle/config/mutator/process_root_includes.go`. +const IncludesVariable = "DATABRICKS_BUNDLE_INCLUDES" + +// Includes returns the bundle Includes environment variable. +func Includes(ctx context.Context) (string, bool) { + return get(ctx, []string{ + IncludesVariable, + }) +} diff --git a/bundle/env/includes_test.go b/bundle/env/includes_test.go new file mode 100644 index 00000000..d9366a59 --- /dev/null +++ b/bundle/env/includes_test.go @@ -0,0 +1,28 @@ +package env + +import ( + "context" + "testing" + + "github.com/databricks/cli/internal/testutil" + "github.com/stretchr/testify/assert" +) + +func TestIncludes(t *testing.T) { + ctx := context.Background() + + testutil.CleanupEnvironment(t) + + t.Run("set", func(t *testing.T) { + t.Setenv("DATABRICKS_BUNDLE_INCLUDES", "foo") + includes, ok := Includes(ctx) + assert.True(t, ok) + assert.Equal(t, "foo", includes) + }) + + t.Run("not set", func(t *testing.T) { + includes, ok := Includes(ctx) + assert.False(t, ok) + assert.Equal(t, "", includes) + }) +} diff --git a/bundle/env/root.go b/bundle/env/root.go new file mode 100644 index 00000000..e3c2a38a --- /dev/null +++ b/bundle/env/root.go @@ -0,0 +1,16 @@ +package env + +import "context" + +// RootVariable names the environment variable that holds the bundle root path. +const RootVariable = "DATABRICKS_BUNDLE_ROOT" + +// Root returns the bundle root environment variable. +func Root(ctx context.Context) (string, bool) { + return get(ctx, []string{ + RootVariable, + + // Primary variable name for the bundle root until v0.204.0. + "BUNDLE_ROOT", + }) +} diff --git a/bundle/env/root_test.go b/bundle/env/root_test.go new file mode 100644 index 00000000..fc2d6e20 --- /dev/null +++ b/bundle/env/root_test.go @@ -0,0 +1,43 @@ +package env + +import ( + "context" + "testing" + + "github.com/databricks/cli/internal/testutil" + "github.com/stretchr/testify/assert" +) + +func TestRoot(t *testing.T) { + ctx := context.Background() + + testutil.CleanupEnvironment(t) + + t.Run("first", func(t *testing.T) { + t.Setenv("DATABRICKS_BUNDLE_ROOT", "foo") + root, ok := Root(ctx) + assert.True(t, ok) + assert.Equal(t, "foo", root) + }) + + t.Run("second", func(t *testing.T) { + t.Setenv("BUNDLE_ROOT", "foo") + root, ok := Root(ctx) + assert.True(t, ok) + assert.Equal(t, "foo", root) + }) + + t.Run("both set", func(t *testing.T) { + t.Setenv("DATABRICKS_BUNDLE_ROOT", "first") + t.Setenv("BUNDLE_ROOT", "second") + root, ok := Root(ctx) + assert.True(t, ok) + assert.Equal(t, "first", root) + }) + + t.Run("not set", func(t *testing.T) { + root, ok := Root(ctx) + assert.False(t, ok) + assert.Equal(t, "", root) + }) +} diff --git a/bundle/env/target.go b/bundle/env/target.go new file mode 100644 index 00000000..ac3b4887 --- /dev/null +++ b/bundle/env/target.go @@ -0,0 +1,17 @@ +package env + +import "context" + +// TargetVariable names the environment variable that holds the bundle target to use. +const TargetVariable = "DATABRICKS_BUNDLE_TARGET" + +// Target returns the bundle target environment variable. +func Target(ctx context.Context) (string, bool) { + return get(ctx, []string{ + TargetVariable, + + // Primary variable name for the bundle target until v0.203.2. + // See https://github.com/databricks/cli/pull/670. + "DATABRICKS_BUNDLE_ENV", + }) +} diff --git a/bundle/env/target_test.go b/bundle/env/target_test.go new file mode 100644 index 00000000..0c15bf91 --- /dev/null +++ b/bundle/env/target_test.go @@ -0,0 +1,43 @@ +package env + +import ( + "context" + "testing" + + "github.com/databricks/cli/internal/testutil" + "github.com/stretchr/testify/assert" +) + +func TestTarget(t *testing.T) { + ctx := context.Background() + + testutil.CleanupEnvironment(t) + + t.Run("first", func(t *testing.T) { + t.Setenv("DATABRICKS_BUNDLE_TARGET", "foo") + target, ok := Target(ctx) + assert.True(t, ok) + assert.Equal(t, "foo", target) + }) + + t.Run("second", func(t *testing.T) { + t.Setenv("DATABRICKS_BUNDLE_ENV", "foo") + target, ok := Target(ctx) + assert.True(t, ok) + assert.Equal(t, "foo", target) + }) + + t.Run("both set", func(t *testing.T) { + t.Setenv("DATABRICKS_BUNDLE_TARGET", "first") + t.Setenv("DATABRICKS_BUNDLE_ENV", "second") + target, ok := Target(ctx) + assert.True(t, ok) + assert.Equal(t, "first", target) + }) + + t.Run("not set", func(t *testing.T) { + target, ok := Target(ctx) + assert.False(t, ok) + assert.Equal(t, "", target) + }) +} diff --git a/bundle/env/temp_dir.go b/bundle/env/temp_dir.go new file mode 100644 index 00000000..b9133907 --- /dev/null +++ b/bundle/env/temp_dir.go @@ -0,0 +1,13 @@ +package env + +import "context" + +// TempDirVariable names the environment variable that holds the temporary directory to use. +const TempDirVariable = "DATABRICKS_BUNDLE_TMP" + +// TempDir returns the temporary directory to use. +func TempDir(ctx context.Context) (string, bool) { + return get(ctx, []string{ + TempDirVariable, + }) +} diff --git a/bundle/env/temp_dir_test.go b/bundle/env/temp_dir_test.go new file mode 100644 index 00000000..7659bac6 --- /dev/null +++ b/bundle/env/temp_dir_test.go @@ -0,0 +1,28 @@ +package env + +import ( + "context" + "testing" + + "github.com/databricks/cli/internal/testutil" + "github.com/stretchr/testify/assert" +) + +func TestTempDir(t *testing.T) { + ctx := context.Background() + + testutil.CleanupEnvironment(t) + + t.Run("set", func(t *testing.T) { + t.Setenv("DATABRICKS_BUNDLE_TMP", "foo") + tempDir, ok := TempDir(ctx) + assert.True(t, ok) + assert.Equal(t, "foo", tempDir) + }) + + t.Run("not set", func(t *testing.T) { + tempDir, ok := TempDir(ctx) + assert.False(t, ok) + assert.Equal(t, "", tempDir) + }) +} diff --git a/bundle/root.go b/bundle/root.go index 46f63e13..7518bf5f 100644 --- a/bundle/root.go +++ b/bundle/root.go @@ -1,21 +1,21 @@ package bundle import ( + "context" "fmt" "os" "github.com/databricks/cli/bundle/config" + "github.com/databricks/cli/bundle/env" "github.com/databricks/cli/folders" ) -const envBundleRoot = "BUNDLE_ROOT" - -// getRootEnv returns the value of the `BUNDLE_ROOT` environment variable +// getRootEnv returns the value of the bundle root environment variable // if it set and is a directory. If the environment variable is set but // is not a directory, it returns an error. If the environment variable is // not set, it returns an empty string. -func getRootEnv() (string, error) { - path, ok := os.LookupEnv(envBundleRoot) +func getRootEnv(ctx context.Context) (string, error) { + path, ok := env.Root(ctx) if !ok { return "", nil } @@ -24,7 +24,7 @@ func getRootEnv() (string, error) { err = fmt.Errorf("not a directory") } if err != nil { - return "", fmt.Errorf(`invalid bundle root %s="%s": %w`, envBundleRoot, path, err) + return "", fmt.Errorf(`invalid bundle root %s="%s": %w`, env.RootVariable, path, err) } return path, nil } @@ -48,8 +48,8 @@ func getRootWithTraversal() (string, error) { } // mustGetRoot returns a bundle root or an error if one cannot be found. -func mustGetRoot() (string, error) { - path, err := getRootEnv() +func mustGetRoot(ctx context.Context) (string, error) { + path, err := getRootEnv(ctx) if path != "" || err != nil { return path, err } @@ -57,9 +57,9 @@ func mustGetRoot() (string, error) { } // tryGetRoot returns a bundle root or an empty string if one cannot be found. -func tryGetRoot() (string, error) { +func tryGetRoot(ctx context.Context) (string, error) { // Note: an invalid value in the environment variable is still an error. - path, err := getRootEnv() + path, err := getRootEnv(ctx) if path != "" || err != nil { return path, err } diff --git a/bundle/root_test.go b/bundle/root_test.go index 0c4c46aa..88113546 100644 --- a/bundle/root_test.go +++ b/bundle/root_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/databricks/cli/bundle/config" + "github.com/databricks/cli/bundle/env" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -32,49 +33,55 @@ func chdir(t *testing.T, dir string) string { } func TestRootFromEnv(t *testing.T) { + ctx := context.Background() dir := t.TempDir() - t.Setenv(envBundleRoot, dir) + t.Setenv(env.RootVariable, dir) // It should pull the root from the environment variable. - root, err := mustGetRoot() + root, err := mustGetRoot(ctx) require.NoError(t, err) require.Equal(t, root, dir) } func TestRootFromEnvDoesntExist(t *testing.T) { + ctx := context.Background() dir := t.TempDir() - t.Setenv(envBundleRoot, filepath.Join(dir, "doesntexist")) + t.Setenv(env.RootVariable, filepath.Join(dir, "doesntexist")) // It should pull the root from the environment variable. - _, err := mustGetRoot() + _, err := mustGetRoot(ctx) require.Errorf(t, err, "invalid bundle root") } func TestRootFromEnvIsFile(t *testing.T) { + ctx := context.Background() dir := t.TempDir() f, err := os.Create(filepath.Join(dir, "invalid")) require.NoError(t, err) f.Close() - t.Setenv(envBundleRoot, f.Name()) + t.Setenv(env.RootVariable, f.Name()) // It should pull the root from the environment variable. - _, err = mustGetRoot() + _, err = mustGetRoot(ctx) require.Errorf(t, err, "invalid bundle root") } func TestRootIfEnvIsEmpty(t *testing.T) { + ctx := context.Background() dir := "" - t.Setenv(envBundleRoot, dir) + t.Setenv(env.RootVariable, dir) // It should pull the root from the environment variable. - _, err := mustGetRoot() + _, err := mustGetRoot(ctx) require.Errorf(t, err, "invalid bundle root") } func TestRootLookup(t *testing.T) { + ctx := context.Background() + // Have to set then unset to allow the testing package to revert it to its original value. - t.Setenv(envBundleRoot, "") - os.Unsetenv(envBundleRoot) + t.Setenv(env.RootVariable, "") + os.Unsetenv(env.RootVariable) chdir(t, t.TempDir()) @@ -89,27 +96,30 @@ func TestRootLookup(t *testing.T) { // It should find the project root from $PWD. wd := chdir(t, "./a/b/c") - root, err := mustGetRoot() + root, err := mustGetRoot(ctx) require.NoError(t, err) require.Equal(t, wd, root) } func TestRootLookupError(t *testing.T) { + ctx := context.Background() + // Have to set then unset to allow the testing package to revert it to its original value. - t.Setenv(envBundleRoot, "") - os.Unsetenv(envBundleRoot) + t.Setenv(env.RootVariable, "") + os.Unsetenv(env.RootVariable) // It can't find a project root from a temporary directory. _ = chdir(t, t.TempDir()) - _, err := mustGetRoot() + _, err := mustGetRoot(ctx) require.ErrorContains(t, err, "unable to locate bundle root") } func TestLoadYamlWhenIncludesEnvPresent(t *testing.T) { + ctx := context.Background() chdir(t, filepath.Join(".", "tests", "basic")) - t.Setenv(ExtraIncludePathsKey, "test") + t.Setenv(env.IncludesVariable, "test") - bundle, err := MustLoad(context.Background()) + bundle, err := MustLoad(ctx) assert.NoError(t, err) assert.Equal(t, "basic", bundle.Config.Bundle.Name) @@ -119,30 +129,33 @@ func TestLoadYamlWhenIncludesEnvPresent(t *testing.T) { } func TestLoadDefautlBundleWhenNoYamlAndRootAndIncludesEnvPresent(t *testing.T) { + ctx := context.Background() dir := t.TempDir() chdir(t, dir) - t.Setenv(envBundleRoot, dir) - t.Setenv(ExtraIncludePathsKey, "test") + t.Setenv(env.RootVariable, dir) + t.Setenv(env.IncludesVariable, "test") - bundle, err := MustLoad(context.Background()) + bundle, err := MustLoad(ctx) assert.NoError(t, err) assert.Equal(t, dir, bundle.Config.Path) } func TestErrorIfNoYamlNoRootEnvAndIncludesEnvPresent(t *testing.T) { + ctx := context.Background() dir := t.TempDir() chdir(t, dir) - t.Setenv(ExtraIncludePathsKey, "test") + t.Setenv(env.IncludesVariable, "test") - _, err := MustLoad(context.Background()) + _, err := MustLoad(ctx) assert.Error(t, err) } func TestErrorIfNoYamlNoIncludesEnvAndRootEnvPresent(t *testing.T) { + ctx := context.Background() dir := t.TempDir() chdir(t, dir) - t.Setenv(envBundleRoot, dir) + t.Setenv(env.RootVariable, dir) - _, err := MustLoad(context.Background()) + _, err := MustLoad(ctx) assert.Error(t, err) } diff --git a/cmd/bundle/sync.go b/cmd/bundle/sync.go index be45626a..6d6a6f5a 100644 --- a/cmd/bundle/sync.go +++ b/cmd/bundle/sync.go @@ -18,12 +18,12 @@ type syncFlags struct { } func (f *syncFlags) syncOptionsFromBundle(cmd *cobra.Command, b *bundle.Bundle) (*sync.SyncOptions, error) { - cacheDir, err := b.CacheDir() + cacheDir, err := b.CacheDir(cmd.Context()) if err != nil { return nil, fmt.Errorf("cannot get bundle cache directory: %w", err) } - includes, err := b.GetSyncIncludePatterns() + includes, err := b.GetSyncIncludePatterns(cmd.Context()) if err != nil { return nil, fmt.Errorf("cannot get list of sync includes: %w", err) } diff --git a/cmd/cmd.go b/cmd/cmd.go index 032fde5c..6dd0f6e2 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "strings" "github.com/databricks/cli/cmd/account" @@ -21,8 +22,8 @@ const ( permissionsGroup = "permissions" ) -func New() *cobra.Command { - cli := root.New() +func New(ctx context.Context) *cobra.Command { + cli := root.New(ctx) // Add account subcommand. cli.AddCommand(account.New()) diff --git a/cmd/configure/configure_test.go b/cmd/configure/configure_test.go index e1ebe916..cf0505ed 100644 --- a/cmd/configure/configure_test.go +++ b/cmd/configure/configure_test.go @@ -54,7 +54,7 @@ func TestDefaultConfigureNoInteractive(t *testing.T) { }) os.Stdin = inp - cmd := cmd.New() + cmd := cmd.New(ctx) cmd.SetArgs([]string{"configure", "--token", "--host", "https://host"}) err := cmd.ExecuteContext(ctx) @@ -87,7 +87,7 @@ func TestConfigFileFromEnvNoInteractive(t *testing.T) { t.Cleanup(func() { os.Stdin = oldStdin }) os.Stdin = inp - cmd := cmd.New() + cmd := cmd.New(ctx) cmd.SetArgs([]string{"configure", "--token", "--host", "https://host"}) err := cmd.ExecuteContext(ctx) @@ -116,7 +116,7 @@ func TestCustomProfileConfigureNoInteractive(t *testing.T) { t.Cleanup(func() { os.Stdin = oldStdin }) os.Stdin = inp - cmd := cmd.New() + cmd := cmd.New(ctx) cmd.SetArgs([]string{"configure", "--token", "--host", "https://host", "--profile", "CUSTOM"}) err := cmd.ExecuteContext(ctx) diff --git a/cmd/root/bundle.go b/cmd/root/bundle.go index 10cce67a..3f9d90db 100644 --- a/cmd/root/bundle.go +++ b/cmd/root/bundle.go @@ -2,17 +2,15 @@ package root import ( "context" - "os" "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config/mutator" + "github.com/databricks/cli/bundle/env" + envlib "github.com/databricks/cli/libs/env" "github.com/spf13/cobra" "golang.org/x/exp/maps" ) -const envName = "DATABRICKS_BUNDLE_ENV" -const targetName = "DATABRICKS_BUNDLE_TARGET" - // getTarget returns the name of the target to operate in. func getTarget(cmd *cobra.Command) (value string) { // The command line flag takes precedence. @@ -33,13 +31,7 @@ func getTarget(cmd *cobra.Command) (value string) { } // If it's not set, use the environment variable. - target := os.Getenv(targetName) - // If target env is not set with a new variable, try to check for old variable name - // TODO: remove when environments section is not supported anymore - if target == "" { - target = os.Getenv(envName) - } - + target, _ := env.Target(cmd.Context()) return target } @@ -54,7 +46,7 @@ func getProfile(cmd *cobra.Command) (value string) { } // If it's not set, use the environment variable. - return os.Getenv("DATABRICKS_CONFIG_PROFILE") + return envlib.Get(cmd.Context(), "DATABRICKS_CONFIG_PROFILE") } // loadBundle loads the bundle configuration and applies default mutators. diff --git a/cmd/root/io.go b/cmd/root/io.go index 380c01b1..23c7d6c6 100644 --- a/cmd/root/io.go +++ b/cmd/root/io.go @@ -1,9 +1,8 @@ package root import ( - "os" - "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/env" "github.com/databricks/cli/libs/flags" "github.com/spf13/cobra" ) @@ -21,7 +20,7 @@ func initOutputFlag(cmd *cobra.Command) *outputFlag { // Configure defaults from environment, if applicable. // If the provided value is invalid it is ignored. - if v, ok := os.LookupEnv(envOutputFormat); ok { + if v, ok := env.Lookup(cmd.Context(), envOutputFormat); ok { f.output.Set(v) } diff --git a/cmd/root/logger.go b/cmd/root/logger.go index ddfae445..dca07ca4 100644 --- a/cmd/root/logger.go +++ b/cmd/root/logger.go @@ -5,9 +5,9 @@ import ( "fmt" "io" "log/slog" - "os" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/env" "github.com/databricks/cli/libs/flags" "github.com/databricks/cli/libs/log" "github.com/fatih/color" @@ -126,13 +126,13 @@ func initLogFlags(cmd *cobra.Command) *logFlags { // Configure defaults from environment, if applicable. // If the provided value is invalid it is ignored. - if v, ok := os.LookupEnv(envLogFile); ok { + if v, ok := env.Lookup(cmd.Context(), envLogFile); ok { f.file.Set(v) } - if v, ok := os.LookupEnv(envLogLevel); ok { + if v, ok := env.Lookup(cmd.Context(), envLogLevel); ok { f.level.Set(v) } - if v, ok := os.LookupEnv(envLogFormat); ok { + if v, ok := env.Lookup(cmd.Context(), envLogFormat); ok { f.output.Set(v) } diff --git a/cmd/root/progress_logger.go b/cmd/root/progress_logger.go index bdf52558..328b9947 100644 --- a/cmd/root/progress_logger.go +++ b/cmd/root/progress_logger.go @@ -6,6 +6,7 @@ import ( "os" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/env" "github.com/databricks/cli/libs/flags" "github.com/spf13/cobra" "golang.org/x/term" @@ -51,7 +52,7 @@ func initProgressLoggerFlag(cmd *cobra.Command, logFlags *logFlags) *progressLog // Configure defaults from environment, if applicable. // If the provided value is invalid it is ignored. - if v, ok := os.LookupEnv(envProgressFormat); ok { + if v, ok := env.Lookup(cmd.Context(), envProgressFormat); ok { f.Set(v) } diff --git a/cmd/root/root.go b/cmd/root/root.go index c71cf9ea..38eb42cc 100644 --- a/cmd/root/root.go +++ b/cmd/root/root.go @@ -14,7 +14,7 @@ import ( "github.com/spf13/cobra" ) -func New() *cobra.Command { +func New(ctx context.Context) *cobra.Command { cmd := &cobra.Command{ Use: "databricks", Short: "Databricks CLI", @@ -30,6 +30,10 @@ func New() *cobra.Command { SilenceErrors: true, } + // Pass the context along through the command during initialization. + // It will be overwritten when the command is executed. + cmd.SetContext(ctx) + // Initialize flags logFlags := initLogFlags(cmd) progressLoggerFlag := initProgressLoggerFlag(cmd, logFlags) diff --git a/cmd/root/user_agent_upstream.go b/cmd/root/user_agent_upstream.go index 3e173bda..f580b426 100644 --- a/cmd/root/user_agent_upstream.go +++ b/cmd/root/user_agent_upstream.go @@ -2,8 +2,8 @@ package root import ( "context" - "os" + "github.com/databricks/cli/libs/env" "github.com/databricks/databricks-sdk-go/useragent" ) @@ -16,7 +16,7 @@ const upstreamKey = "upstream" const upstreamVersionKey = "upstream-version" func withUpstreamInUserAgent(ctx context.Context) context.Context { - value := os.Getenv(upstreamEnvVar) + value := env.Get(ctx, upstreamEnvVar) if value == "" { return ctx } @@ -24,7 +24,7 @@ func withUpstreamInUserAgent(ctx context.Context) context.Context { ctx = useragent.InContext(ctx, upstreamKey, value) // Include upstream version as well, if set. - value = os.Getenv(upstreamVersionEnvVar) + value = env.Get(ctx, upstreamVersionEnvVar) if value == "" { return ctx } diff --git a/cmd/sync/sync.go b/cmd/sync/sync.go index 4a62123b..5fdfb169 100644 --- a/cmd/sync/sync.go +++ b/cmd/sync/sync.go @@ -30,12 +30,12 @@ func (f *syncFlags) syncOptionsFromBundle(cmd *cobra.Command, args []string, b * return nil, fmt.Errorf("SRC and DST are not configurable in the context of a bundle") } - cacheDir, err := b.CacheDir() + cacheDir, err := b.CacheDir(cmd.Context()) if err != nil { return nil, fmt.Errorf("cannot get bundle cache directory: %w", err) } - includes, err := b.GetSyncIncludePatterns() + includes, err := b.GetSyncIncludePatterns(cmd.Context()) if err != nil { return nil, fmt.Errorf("cannot get list of sync includes: %w", err) } diff --git a/internal/helpers.go b/internal/helpers.go index bf27fbb5..68c00019 100644 --- a/internal/helpers.go +++ b/internal/helpers.go @@ -118,7 +118,7 @@ func (t *cobraTestRunner) RunBackground() { var stdoutW, stderrW io.WriteCloser stdoutR, stdoutW = io.Pipe() stderrR, stderrW = io.Pipe() - root := cmd.New() + root := cmd.New(context.Background()) root.SetOut(stdoutW) root.SetErr(stderrW) root.SetArgs(t.args) diff --git a/internal/testutil/env.go b/internal/testutil/env.go new file mode 100644 index 00000000..05ffaf00 --- /dev/null +++ b/internal/testutil/env.go @@ -0,0 +1,33 @@ +package testutil + +import ( + "os" + "strings" + "testing" +) + +// CleanupEnvironment sets up a pristine environment containing only $PATH and $HOME. +// The original environment is restored upon test completion. +// Note: use of this function is incompatible with parallel execution. +func CleanupEnvironment(t *testing.T) { + // Restore environment when test finishes. + environ := os.Environ() + t.Cleanup(func() { + // Restore original environment. + for _, kv := range environ { + kvs := strings.SplitN(kv, "=", 2) + os.Setenv(kvs[0], kvs[1]) + } + }) + + path := os.Getenv("PATH") + pwd := os.Getenv("PWD") + os.Clearenv() + + // We use t.Setenv instead of os.Setenv because the former actively + // prevents a test being run with t.Parallel. Modifying the environment + // within a test is not compatible with running tests in parallel + // because of isolation; the environment is scoped to the process. + t.Setenv("PATH", path) + t.Setenv("HOME", pwd) +} diff --git a/libs/env/context.go b/libs/env/context.go new file mode 100644 index 00000000..cf04c1ec --- /dev/null +++ b/libs/env/context.go @@ -0,0 +1,63 @@ +package env + +import ( + "context" + "os" +) + +var envContextKey int + +func copyMap(m map[string]string) map[string]string { + out := make(map[string]string, len(m)) + for k, v := range m { + out[k] = v + } + return out +} + +func getMap(ctx context.Context) map[string]string { + if ctx == nil { + return nil + } + m, ok := ctx.Value(&envContextKey).(map[string]string) + if !ok { + return nil + } + return m +} + +func setMap(ctx context.Context, m map[string]string) context.Context { + return context.WithValue(ctx, &envContextKey, m) +} + +// Lookup key in the context or the the environment. +// Context has precedence. +func Lookup(ctx context.Context, key string) (string, bool) { + m := getMap(ctx) + + // Return if the key is set in the context. + v, ok := m[key] + if ok { + return v, true + } + + // Fall back to the environment. + return os.LookupEnv(key) +} + +// Get key from the context or the environment. +// Context has precedence. +func Get(ctx context.Context, key string) string { + v, _ := Lookup(ctx, key) + return v +} + +// Set key on the context. +// +// Note: this does NOT mutate the processes' actual environment variables. +// It is only visible to other code that uses this package. +func Set(ctx context.Context, key, value string) context.Context { + m := copyMap(getMap(ctx)) + m[key] = value + return setMap(ctx, m) +} diff --git a/libs/env/context_test.go b/libs/env/context_test.go new file mode 100644 index 00000000..9ff19459 --- /dev/null +++ b/libs/env/context_test.go @@ -0,0 +1,41 @@ +package env + +import ( + "context" + "testing" + + "github.com/databricks/cli/internal/testutil" + "github.com/stretchr/testify/assert" +) + +func TestContext(t *testing.T) { + testutil.CleanupEnvironment(t) + t.Setenv("FOO", "bar") + + ctx0 := context.Background() + + // Get + assert.Equal(t, "bar", Get(ctx0, "FOO")) + assert.Equal(t, "", Get(ctx0, "dontexist")) + + // Lookup + v, ok := Lookup(ctx0, "FOO") + assert.True(t, ok) + assert.Equal(t, "bar", v) + v, ok = Lookup(ctx0, "dontexist") + assert.False(t, ok) + assert.Equal(t, "", v) + + // Set and get new context. + // Verify that the previous context remains unchanged. + ctx1 := Set(ctx0, "FOO", "baz") + assert.Equal(t, "baz", Get(ctx1, "FOO")) + assert.Equal(t, "bar", Get(ctx0, "FOO")) + + // Set and get new context. + // Verify that the previous contexts remains unchanged. + ctx2 := Set(ctx1, "FOO", "qux") + assert.Equal(t, "qux", Get(ctx2, "FOO")) + assert.Equal(t, "baz", Get(ctx1, "FOO")) + assert.Equal(t, "bar", Get(ctx0, "FOO")) +} diff --git a/libs/env/pkg.go b/libs/env/pkg.go new file mode 100644 index 00000000..e0be7e22 --- /dev/null +++ b/libs/env/pkg.go @@ -0,0 +1,7 @@ +package env + +// The env package provides functions for working with environment variables +// and allowing for overrides via the context.Context. This is useful for +// testing where tainting a processes' environment is at odds with parallelism. +// Use of a context.Context to store variable overrides means tests can be +// parallelized without worrying about environment variable interference. diff --git a/main.go b/main.go index a4b8aabd..8c8516d9 100644 --- a/main.go +++ b/main.go @@ -1,10 +1,12 @@ package main import ( + "context" + "github.com/databricks/cli/cmd" "github.com/databricks/cli/cmd/root" ) func main() { - root.Execute(cmd.New()) + root.Execute(cmd.New(context.Background())) } diff --git a/main_test.go b/main_test.go index 6a5d1944..34ecdca0 100644 --- a/main_test.go +++ b/main_test.go @@ -1,6 +1,7 @@ package main import ( + "context" "testing" "github.com/databricks/cli/cmd" @@ -15,7 +16,7 @@ func TestCommandsDontUseUnderscoreInName(t *testing.T) { // This test lives in the main package because this is where // all commands are imported. // - queue := []*cobra.Command{cmd.New()} + queue := []*cobra.Command{cmd.New(context.Background())} for len(queue) > 0 { cmd := queue[0] assert.NotContains(t, cmd.Name(), "_")