Consolidate environment variable interaction (#747)

## Changes

There are a couple places throughout the code base where interaction
with environment variables takes place. Moreover, more than one of these
would try to read a value from more than one environment variable as
fallback (for backwards compatibility). This change consolidates those
accesses.

The majority of diffs in this change are mechanical (i.e. add an
argument or replace a call).

This change:
* Moves common environment variable lookups for bundles to
`bundles/env`.
* Adds a `libs/env` package that wraps `os.LookupEnv` and `os.Getenv`
and allows for overrides to take place in a `context.Context`. By
scoping overrides to a `context.Context` we can avoid `t.Setenv` in
testing and unlock parallel test execution for integration tests.
* Updates call sites to pass through a `context.Context` where needed.
* For bundles, introduces `DATABRICKS_BUNDLE_ROOT` as new primary
variable instead of `BUNDLE_ROOT`. This was the last environment
variable that did not use the `DATABRICKS_` prefix.

## Tests

Unit tests pass.
This commit is contained in:
Pieter Noordhuis 2023-09-11 10:18:43 +02:00 committed by GitHub
parent 9a51f72f0b
commit 4ccc70aeac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
46 changed files with 594 additions and 164 deletions

View File

@ -14,6 +14,7 @@ import (
"sync" "sync"
"github.com/databricks/cli/bundle/config" "github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/env"
"github.com/databricks/cli/folders" "github.com/databricks/cli/folders"
"github.com/databricks/cli/libs/git" "github.com/databricks/cli/libs/git"
"github.com/databricks/cli/libs/locker" "github.com/databricks/cli/libs/locker"
@ -51,8 +52,6 @@ type Bundle struct {
AutoApprove bool AutoApprove bool
} }
const ExtraIncludePathsKey string = "DATABRICKS_BUNDLE_INCLUDES"
func Load(ctx context.Context, path string) (*Bundle, error) { func Load(ctx context.Context, path string) (*Bundle, error) {
bundle := &Bundle{} bundle := &Bundle{}
stat, err := os.Stat(path) stat, err := os.Stat(path)
@ -61,9 +60,9 @@ func Load(ctx context.Context, path string) (*Bundle, error) {
} }
configFile, err := config.FileNames.FindInPath(path) configFile, err := config.FileNames.FindInPath(path)
if err != nil { if err != nil {
_, hasIncludePathEnv := os.LookupEnv(ExtraIncludePathsKey) _, hasRootEnv := env.Root(ctx)
_, hasBundleRootEnv := os.LookupEnv(envBundleRoot) _, hasIncludesEnv := env.Includes(ctx)
if hasIncludePathEnv && hasBundleRootEnv && stat.IsDir() { if hasRootEnv && hasIncludesEnv && stat.IsDir() {
log.Debugf(ctx, "No bundle configuration; using bundle root: %s", path) log.Debugf(ctx, "No bundle configuration; using bundle root: %s", path)
bundle.Config = config.Root{ bundle.Config = config.Root{
Path: path, Path: path,
@ -86,7 +85,7 @@ func Load(ctx context.Context, path string) (*Bundle, error) {
// MustLoad returns a bundle configuration. // MustLoad returns a bundle configuration.
// It returns an error if a bundle was not found or could not be loaded. // It returns an error if a bundle was not found or could not be loaded.
func MustLoad(ctx context.Context) (*Bundle, error) { func MustLoad(ctx context.Context) (*Bundle, error) {
root, err := mustGetRoot() root, err := mustGetRoot(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -98,7 +97,7 @@ func MustLoad(ctx context.Context) (*Bundle, error) {
// It returns an error if a bundle was found but could not be loaded. // It returns an error if a bundle was found but could not be loaded.
// It returns a `nil` bundle if a bundle was not found. // It returns a `nil` bundle if a bundle was not found.
func TryLoad(ctx context.Context) (*Bundle, error) { func TryLoad(ctx context.Context) (*Bundle, error) {
root, err := tryGetRoot() root, err := tryGetRoot(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -124,13 +123,12 @@ func (b *Bundle) WorkspaceClient() *databricks.WorkspaceClient {
// CacheDir returns directory to use for temporary files for this bundle. // CacheDir returns directory to use for temporary files for this bundle.
// Scoped to the bundle's target. // Scoped to the bundle's target.
func (b *Bundle) CacheDir(paths ...string) (string, error) { func (b *Bundle) CacheDir(ctx context.Context, paths ...string) (string, error) {
if b.Config.Bundle.Target == "" { if b.Config.Bundle.Target == "" {
panic("target not set") panic("target not set")
} }
cacheDirName, exists := os.LookupEnv("DATABRICKS_BUNDLE_TMP") cacheDirName, exists := env.TempDir(ctx)
if !exists || cacheDirName == "" { if !exists || cacheDirName == "" {
cacheDirName = filepath.Join( cacheDirName = filepath.Join(
// Anchor at bundle root directory. // Anchor at bundle root directory.
@ -163,8 +161,8 @@ func (b *Bundle) CacheDir(paths ...string) (string, error) {
// This directory is used to store and automaticaly sync internal bundle files, such as, f.e // This directory is used to store and automaticaly sync internal bundle files, such as, f.e
// notebook trampoline files for Python wheel and etc. // notebook trampoline files for Python wheel and etc.
func (b *Bundle) InternalDir() (string, error) { func (b *Bundle) InternalDir(ctx context.Context) (string, error) {
cacheDir, err := b.CacheDir() cacheDir, err := b.CacheDir(ctx)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -181,8 +179,8 @@ func (b *Bundle) InternalDir() (string, error) {
// GetSyncIncludePatterns returns a list of user defined includes // GetSyncIncludePatterns returns a list of user defined includes
// And also adds InternalDir folder to include list for sync command // And also adds InternalDir folder to include list for sync command
// so this folder is always synced // so this folder is always synced
func (b *Bundle) GetSyncIncludePatterns() ([]string, error) { func (b *Bundle) GetSyncIncludePatterns(ctx context.Context) ([]string, error) {
internalDir, err := b.InternalDir() internalDir, err := b.InternalDir(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -6,6 +6,7 @@ import (
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/databricks/cli/bundle/env"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -23,12 +24,13 @@ func TestLoadExists(t *testing.T) {
} }
func TestBundleCacheDir(t *testing.T) { func TestBundleCacheDir(t *testing.T) {
ctx := context.Background()
projectDir := t.TempDir() projectDir := t.TempDir()
f1, err := os.Create(filepath.Join(projectDir, "databricks.yml")) f1, err := os.Create(filepath.Join(projectDir, "databricks.yml"))
require.NoError(t, err) require.NoError(t, err)
f1.Close() f1.Close()
bundle, err := Load(context.Background(), projectDir) bundle, err := Load(ctx, projectDir)
require.NoError(t, err) require.NoError(t, err)
// Artificially set target. // Artificially set target.
@ -38,7 +40,7 @@ func TestBundleCacheDir(t *testing.T) {
// unset env variable in case it's set // unset env variable in case it's set
t.Setenv("DATABRICKS_BUNDLE_TMP", "") t.Setenv("DATABRICKS_BUNDLE_TMP", "")
cacheDir, err := bundle.CacheDir() cacheDir, err := bundle.CacheDir(ctx)
// format is <CWD>/.databricks/bundle/<target> // format is <CWD>/.databricks/bundle/<target>
assert.NoError(t, err) assert.NoError(t, err)
@ -46,13 +48,14 @@ func TestBundleCacheDir(t *testing.T) {
} }
func TestBundleCacheDirOverride(t *testing.T) { func TestBundleCacheDirOverride(t *testing.T) {
ctx := context.Background()
projectDir := t.TempDir() projectDir := t.TempDir()
bundleTmpDir := t.TempDir() bundleTmpDir := t.TempDir()
f1, err := os.Create(filepath.Join(projectDir, "databricks.yml")) f1, err := os.Create(filepath.Join(projectDir, "databricks.yml"))
require.NoError(t, err) require.NoError(t, err)
f1.Close() f1.Close()
bundle, err := Load(context.Background(), projectDir) bundle, err := Load(ctx, projectDir)
require.NoError(t, err) require.NoError(t, err)
// Artificially set target. // Artificially set target.
@ -62,7 +65,7 @@ func TestBundleCacheDirOverride(t *testing.T) {
// now we expect to use 'bundleTmpDir' instead of CWD/.databricks/bundle // now we expect to use 'bundleTmpDir' instead of CWD/.databricks/bundle
t.Setenv("DATABRICKS_BUNDLE_TMP", bundleTmpDir) t.Setenv("DATABRICKS_BUNDLE_TMP", bundleTmpDir)
cacheDir, err := bundle.CacheDir() cacheDir, err := bundle.CacheDir(ctx)
// format is <DATABRICKS_BUNDLE_TMP>/<target> // format is <DATABRICKS_BUNDLE_TMP>/<target>
assert.NoError(t, err) assert.NoError(t, err)
@ -70,14 +73,14 @@ func TestBundleCacheDirOverride(t *testing.T) {
} }
func TestBundleMustLoadSuccess(t *testing.T) { func TestBundleMustLoadSuccess(t *testing.T) {
t.Setenv(envBundleRoot, "./tests/basic") t.Setenv(env.RootVariable, "./tests/basic")
b, err := MustLoad(context.Background()) b, err := MustLoad(context.Background())
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "tests/basic", filepath.ToSlash(b.Config.Path)) assert.Equal(t, "tests/basic", filepath.ToSlash(b.Config.Path))
} }
func TestBundleMustLoadFailureWithEnv(t *testing.T) { func TestBundleMustLoadFailureWithEnv(t *testing.T) {
t.Setenv(envBundleRoot, "./tests/doesntexist") t.Setenv(env.RootVariable, "./tests/doesntexist")
_, err := MustLoad(context.Background()) _, err := MustLoad(context.Background())
require.Error(t, err, "not a directory") require.Error(t, err, "not a directory")
} }
@ -89,14 +92,14 @@ func TestBundleMustLoadFailureIfNotFound(t *testing.T) {
} }
func TestBundleTryLoadSuccess(t *testing.T) { func TestBundleTryLoadSuccess(t *testing.T) {
t.Setenv(envBundleRoot, "./tests/basic") t.Setenv(env.RootVariable, "./tests/basic")
b, err := TryLoad(context.Background()) b, err := TryLoad(context.Background())
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "tests/basic", filepath.ToSlash(b.Config.Path)) assert.Equal(t, "tests/basic", filepath.ToSlash(b.Config.Path))
} }
func TestBundleTryLoadFailureWithEnv(t *testing.T) { func TestBundleTryLoadFailureWithEnv(t *testing.T) {
t.Setenv(envBundleRoot, "./tests/doesntexist") t.Setenv(env.RootVariable, "./tests/doesntexist")
_, err := TryLoad(context.Background()) _, err := TryLoad(context.Background())
require.Error(t, err, "not a directory") require.Error(t, err, "not a directory")
} }

View File

@ -3,11 +3,11 @@ package mutator
import ( import (
"context" "context"
"fmt" "fmt"
"os"
"github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config" "github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/resources" "github.com/databricks/cli/bundle/config/resources"
"github.com/databricks/cli/libs/env"
) )
type overrideCompute struct{} type overrideCompute struct{}
@ -39,8 +39,8 @@ func (m *overrideCompute) Apply(ctx context.Context, b *bundle.Bundle) error {
} }
return nil return nil
} }
if os.Getenv("DATABRICKS_CLUSTER_ID") != "" { if v := env.Get(ctx, "DATABRICKS_CLUSTER_ID"); v != "" {
b.Config.Bundle.ComputeID = os.Getenv("DATABRICKS_CLUSTER_ID") b.Config.Bundle.ComputeID = v
} }
if b.Config.Bundle.ComputeID == "" { if b.Config.Bundle.ComputeID == "" {

View File

@ -10,11 +10,12 @@ import (
"github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config" "github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/env"
) )
// Get extra include paths from environment variable // Get extra include paths from environment variable
func GetExtraIncludePaths() []string { func getExtraIncludePaths(ctx context.Context) []string {
value, exists := os.LookupEnv(bundle.ExtraIncludePathsKey) value, exists := env.Includes(ctx)
if !exists { if !exists {
return nil return nil
} }
@ -48,7 +49,7 @@ func (m *processRootIncludes) Apply(ctx context.Context, b *bundle.Bundle) error
var files []string var files []string
// Converts extra include paths from environment variable to relative paths // Converts extra include paths from environment variable to relative paths
for _, extraIncludePath := range GetExtraIncludePaths() { for _, extraIncludePath := range getExtraIncludePaths(ctx) {
if filepath.IsAbs(extraIncludePath) { if filepath.IsAbs(extraIncludePath) {
rel, err := filepath.Rel(b.Config.Path, extraIncludePath) rel, err := filepath.Rel(b.Config.Path, extraIncludePath)
if err != nil { if err != nil {

View File

@ -2,16 +2,17 @@ package mutator_test
import ( import (
"context" "context"
"fmt"
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strings"
"testing" "testing"
"github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config" "github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/mutator" "github.com/databricks/cli/bundle/config/mutator"
"github.com/databricks/cli/bundle/env"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -129,10 +130,7 @@ func TestProcessRootIncludesExtrasFromEnvVar(t *testing.T) {
rootPath := t.TempDir() rootPath := t.TempDir()
testYamlName := "extra_include_path.yml" testYamlName := "extra_include_path.yml"
touch(t, rootPath, testYamlName) touch(t, rootPath, testYamlName)
os.Setenv(bundle.ExtraIncludePathsKey, path.Join(rootPath, testYamlName)) t.Setenv(env.IncludesVariable, path.Join(rootPath, testYamlName))
t.Cleanup(func() {
os.Unsetenv(bundle.ExtraIncludePathsKey)
})
bundle := &bundle.Bundle{ bundle := &bundle.Bundle{
Config: config.Root{ Config: config.Root{
@ -149,7 +147,13 @@ func TestProcessRootIncludesDedupExtrasFromEnvVar(t *testing.T) {
rootPath := t.TempDir() rootPath := t.TempDir()
testYamlName := "extra_include_path.yml" testYamlName := "extra_include_path.yml"
touch(t, rootPath, testYamlName) touch(t, rootPath, testYamlName)
t.Setenv(bundle.ExtraIncludePathsKey, fmt.Sprintf("%s%s%s", path.Join(rootPath, testYamlName), string(os.PathListSeparator), path.Join(rootPath, testYamlName))) t.Setenv(env.IncludesVariable, strings.Join(
[]string{
path.Join(rootPath, testYamlName),
path.Join(rootPath, testYamlName),
},
string(os.PathListSeparator),
))
bundle := &bundle.Bundle{ bundle := &bundle.Bundle{
Config: config.Root{ Config: config.Root{

View File

@ -3,10 +3,10 @@ package mutator
import ( import (
"context" "context"
"fmt" "fmt"
"os"
"github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config/variable" "github.com/databricks/cli/bundle/config/variable"
"github.com/databricks/cli/libs/env"
) )
const bundleVarPrefix = "BUNDLE_VAR_" const bundleVarPrefix = "BUNDLE_VAR_"
@ -21,7 +21,7 @@ func (m *setVariables) Name() string {
return "SetVariables" return "SetVariables"
} }
func setVariable(v *variable.Variable, name string) error { func setVariable(ctx context.Context, v *variable.Variable, name string) error {
// case: variable already has value initialized, so skip // case: variable already has value initialized, so skip
if v.HasValue() { if v.HasValue() {
return nil return nil
@ -29,7 +29,7 @@ func setVariable(v *variable.Variable, name string) error {
// case: read and set variable value from process environment // case: read and set variable value from process environment
envVarName := bundleVarPrefix + name envVarName := bundleVarPrefix + name
if val, ok := os.LookupEnv(envVarName); ok { if val, ok := env.Lookup(ctx, envVarName); ok {
err := v.Set(val) err := v.Set(val)
if err != nil { if err != nil {
return fmt.Errorf(`failed to assign value "%s" to variable %s from environment variable %s with error: %w`, val, name, envVarName, err) return fmt.Errorf(`failed to assign value "%s" to variable %s from environment variable %s with error: %w`, val, name, envVarName, err)
@ -54,7 +54,7 @@ func setVariable(v *variable.Variable, name string) error {
func (m *setVariables) Apply(ctx context.Context, b *bundle.Bundle) error { func (m *setVariables) Apply(ctx context.Context, b *bundle.Bundle) error {
for name, variable := range b.Config.Variables { for name, variable := range b.Config.Variables {
err := setVariable(variable, name) err := setVariable(ctx, variable, name)
if err != nil { if err != nil {
return err return err
} }

View File

@ -21,7 +21,7 @@ func TestSetVariableFromProcessEnvVar(t *testing.T) {
// set value for variable as an environment variable // set value for variable as an environment variable
t.Setenv("BUNDLE_VAR_foo", "process-env") t.Setenv("BUNDLE_VAR_foo", "process-env")
err := setVariable(&variable, "foo") err := setVariable(context.Background(), &variable, "foo")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, *variable.Value, "process-env") assert.Equal(t, *variable.Value, "process-env")
} }
@ -33,7 +33,7 @@ func TestSetVariableUsingDefaultValue(t *testing.T) {
Default: &defaultVal, Default: &defaultVal,
} }
err := setVariable(&variable, "foo") err := setVariable(context.Background(), &variable, "foo")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, *variable.Value, "default") assert.Equal(t, *variable.Value, "default")
} }
@ -49,7 +49,7 @@ func TestSetVariableWhenAlreadyAValueIsAssigned(t *testing.T) {
// since a value is already assigned to the variable, it would not be overridden // since a value is already assigned to the variable, it would not be overridden
// by the default value // by the default value
err := setVariable(&variable, "foo") err := setVariable(context.Background(), &variable, "foo")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, *variable.Value, "assigned-value") assert.Equal(t, *variable.Value, "assigned-value")
} }
@ -68,7 +68,7 @@ func TestSetVariableEnvVarValueDoesNotOverridePresetValue(t *testing.T) {
// since a value is already assigned to the variable, it would not be overridden // since a value is already assigned to the variable, it would not be overridden
// by the value from environment // by the value from environment
err := setVariable(&variable, "foo") err := setVariable(context.Background(), &variable, "foo")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, *variable.Value, "assigned-value") assert.Equal(t, *variable.Value, "assigned-value")
} }
@ -79,7 +79,7 @@ func TestSetVariablesErrorsIfAValueCouldNotBeResolved(t *testing.T) {
} }
// fails because we could not resolve a value for the variable // fails because we could not resolve a value for the variable
err := setVariable(&variable, "foo") err := setVariable(context.Background(), &variable, "foo")
assert.ErrorContains(t, err, "no value assigned to required variable foo. Assignment can be done through the \"--var\" flag or by setting the BUNDLE_VAR_foo environment variable") assert.ErrorContains(t, err, "no value assigned to required variable foo. Assignment can be done through the \"--var\" flag or by setting the BUNDLE_VAR_foo environment variable")
} }

View File

@ -43,7 +43,7 @@ func (m *trampoline) Name() string {
func (m *trampoline) Apply(ctx context.Context, b *bundle.Bundle) error { func (m *trampoline) Apply(ctx context.Context, b *bundle.Bundle) error {
tasks := m.functions.GetTasks(b) tasks := m.functions.GetTasks(b)
for _, task := range tasks { for _, task := range tasks {
err := m.generateNotebookWrapper(b, task) err := m.generateNotebookWrapper(ctx, b, task)
if err != nil { if err != nil {
return err return err
} }
@ -51,8 +51,8 @@ func (m *trampoline) Apply(ctx context.Context, b *bundle.Bundle) error {
return nil return nil
} }
func (m *trampoline) generateNotebookWrapper(b *bundle.Bundle, task TaskWithJobKey) error { func (m *trampoline) generateNotebookWrapper(ctx context.Context, b *bundle.Bundle, task TaskWithJobKey) error {
internalDir, err := b.InternalDir() internalDir, err := b.InternalDir(ctx)
if err != nil { if err != nil {
return err return err
} }

View File

@ -83,7 +83,7 @@ func TestGenerateTrampoline(t *testing.T) {
err := bundle.Apply(ctx, b, trampoline) err := bundle.Apply(ctx, b, trampoline)
require.NoError(t, err) require.NoError(t, err)
dir, err := b.InternalDir() dir, err := b.InternalDir(ctx)
require.NoError(t, err) require.NoError(t, err)
filename := filepath.Join(dir, "notebook_test_to_trampoline.py") filename := filepath.Join(dir, "notebook_test_to_trampoline.py")

View File

@ -9,12 +9,12 @@ import (
) )
func getSync(ctx context.Context, b *bundle.Bundle) (*sync.Sync, error) { func getSync(ctx context.Context, b *bundle.Bundle) (*sync.Sync, error) {
cacheDir, err := b.CacheDir() cacheDir, err := b.CacheDir(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot get bundle cache directory: %w", err) return nil, fmt.Errorf("cannot get bundle cache directory: %w", err)
} }
includes, err := b.GetSyncIncludePatterns() includes, err := b.GetSyncIncludePatterns(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot get list of sync includes: %w", err) return nil, fmt.Errorf("cannot get list of sync includes: %w", err)
} }

View File

@ -1,11 +1,13 @@
package terraform package terraform
import ( import (
"context"
"github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle"
) )
// Dir returns the Terraform working directory for a given bundle. // Dir returns the Terraform working directory for a given bundle.
// The working directory is emphemeral and nested under the bundle's cache directory. // The working directory is emphemeral and nested under the bundle's cache directory.
func Dir(b *bundle.Bundle) (string, error) { func Dir(ctx context.Context, b *bundle.Bundle) (string, error) {
return b.CacheDir("terraform") return b.CacheDir(ctx, "terraform")
} }

View File

@ -12,6 +12,7 @@ import (
"github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config" "github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/libs/env"
"github.com/databricks/cli/libs/log" "github.com/databricks/cli/libs/log"
"github.com/hashicorp/go-version" "github.com/hashicorp/go-version"
"github.com/hashicorp/hc-install/product" "github.com/hashicorp/hc-install/product"
@ -38,7 +39,7 @@ func (m *initialize) findExecPath(ctx context.Context, b *bundle.Bundle, tf *con
return tf.ExecPath, nil return tf.ExecPath, nil
} }
binDir, err := b.CacheDir("bin") binDir, err := b.CacheDir(context.Background(), "bin")
if err != nil { if err != nil {
return "", err return "", err
} }
@ -73,25 +74,25 @@ func (m *initialize) findExecPath(ctx context.Context, b *bundle.Bundle, tf *con
} }
// This function inherits some environment variables for Terraform CLI. // This function inherits some environment variables for Terraform CLI.
func inheritEnvVars(env map[string]string) error { func inheritEnvVars(ctx context.Context, environ map[string]string) error {
// Include $HOME in set of environment variables to pass along. // Include $HOME in set of environment variables to pass along.
home, ok := os.LookupEnv("HOME") home, ok := env.Lookup(ctx, "HOME")
if ok { if ok {
env["HOME"] = home environ["HOME"] = home
} }
// Include $PATH in set of environment variables to pass along. // Include $PATH in set of environment variables to pass along.
// This is necessary to ensure that our Terraform provider can use the // This is necessary to ensure that our Terraform provider can use the
// same auxiliary programs (e.g. `az`, or `gcloud`) as the CLI. // same auxiliary programs (e.g. `az`, or `gcloud`) as the CLI.
path, ok := os.LookupEnv("PATH") path, ok := env.Lookup(ctx, "PATH")
if ok { if ok {
env["PATH"] = path environ["PATH"] = path
} }
// Include $TF_CLI_CONFIG_FILE to override terraform provider in development. // Include $TF_CLI_CONFIG_FILE to override terraform provider in development.
configFile, ok := os.LookupEnv("TF_CLI_CONFIG_FILE") configFile, ok := env.Lookup(ctx, "TF_CLI_CONFIG_FILE")
if ok { if ok {
env["TF_CLI_CONFIG_FILE"] = configFile environ["TF_CLI_CONFIG_FILE"] = configFile
} }
return nil return nil
@ -105,40 +106,40 @@ func inheritEnvVars(env map[string]string) error {
// the CLI and its dependencies do not have access to. // the CLI and its dependencies do not have access to.
// //
// see: os.TempDir for more context // see: os.TempDir for more context
func setTempDirEnvVars(env map[string]string, b *bundle.Bundle) error { func setTempDirEnvVars(ctx context.Context, environ map[string]string, b *bundle.Bundle) error {
switch runtime.GOOS { switch runtime.GOOS {
case "windows": case "windows":
if v, ok := os.LookupEnv("TMP"); ok { if v, ok := env.Lookup(ctx, "TMP"); ok {
env["TMP"] = v environ["TMP"] = v
} else if v, ok := os.LookupEnv("TEMP"); ok { } else if v, ok := env.Lookup(ctx, "TEMP"); ok {
env["TEMP"] = v environ["TEMP"] = v
} else if v, ok := os.LookupEnv("USERPROFILE"); ok { } else if v, ok := env.Lookup(ctx, "USERPROFILE"); ok {
env["USERPROFILE"] = v environ["USERPROFILE"] = v
} else { } else {
tmpDir, err := b.CacheDir("tmp") tmpDir, err := b.CacheDir(ctx, "tmp")
if err != nil { if err != nil {
return err return err
} }
env["TMP"] = tmpDir environ["TMP"] = tmpDir
} }
default: default:
// If TMPDIR is not set, we let the process fall back to its default value. // If TMPDIR is not set, we let the process fall back to its default value.
if v, ok := os.LookupEnv("TMPDIR"); ok { if v, ok := env.Lookup(ctx, "TMPDIR"); ok {
env["TMPDIR"] = v environ["TMPDIR"] = v
} }
} }
return nil return nil
} }
// This function passes through all proxy related environment variables. // This function passes through all proxy related environment variables.
func setProxyEnvVars(env map[string]string, b *bundle.Bundle) error { func setProxyEnvVars(ctx context.Context, environ map[string]string, b *bundle.Bundle) error {
for _, v := range []string{"http_proxy", "https_proxy", "no_proxy"} { for _, v := range []string{"http_proxy", "https_proxy", "no_proxy"} {
// The case (upper or lower) is notoriously inconsistent for tools on Unix systems. // The case (upper or lower) is notoriously inconsistent for tools on Unix systems.
// We therefore try to read both the upper and lower case versions of the variable. // We therefore try to read both the upper and lower case versions of the variable.
for _, v := range []string{strings.ToUpper(v), strings.ToLower(v)} { for _, v := range []string{strings.ToUpper(v), strings.ToLower(v)} {
if val, ok := os.LookupEnv(v); ok { if val, ok := env.Lookup(ctx, v); ok {
// Only set uppercase version of the variable. // Only set uppercase version of the variable.
env[strings.ToUpper(v)] = val environ[strings.ToUpper(v)] = val
} }
} }
} }
@ -157,7 +158,7 @@ func (m *initialize) Apply(ctx context.Context, b *bundle.Bundle) error {
return err return err
} }
workingDir, err := Dir(b) workingDir, err := Dir(ctx, b)
if err != nil { if err != nil {
return err return err
} }
@ -167,31 +168,31 @@ func (m *initialize) Apply(ctx context.Context, b *bundle.Bundle) error {
return err return err
} }
env, err := b.AuthEnv() environ, err := b.AuthEnv()
if err != nil { if err != nil {
return err return err
} }
err = inheritEnvVars(env) err = inheritEnvVars(ctx, environ)
if err != nil { if err != nil {
return err return err
} }
// Set the temporary directory environment variables // Set the temporary directory environment variables
err = setTempDirEnvVars(env, b) err = setTempDirEnvVars(ctx, environ, b)
if err != nil { if err != nil {
return err return err
} }
// Set the proxy related environment variables // Set the proxy related environment variables
err = setProxyEnvVars(env, b) err = setProxyEnvVars(ctx, environ, b)
if err != nil { if err != nil {
return err return err
} }
// Configure environment variables for auth for Terraform to use. // Configure environment variables for auth for Terraform to use.
log.Debugf(ctx, "Environment variables for Terraform: %s", strings.Join(maps.Keys(env), ", ")) log.Debugf(ctx, "Environment variables for Terraform: %s", strings.Join(maps.Keys(environ), ", "))
err = tf.SetEnv(env) err = tf.SetEnv(environ)
if err != nil { if err != nil {
return err return err
} }

View File

@ -68,7 +68,7 @@ func TestSetTempDirEnvVarsForUnixWithTmpDirSet(t *testing.T) {
// compute env // compute env
env := make(map[string]string, 0) env := make(map[string]string, 0)
err := setTempDirEnvVars(env, b) err := setTempDirEnvVars(context.Background(), env, b)
require.NoError(t, err) require.NoError(t, err)
// Assert that we pass through TMPDIR. // Assert that we pass through TMPDIR.
@ -96,7 +96,7 @@ func TestSetTempDirEnvVarsForUnixWithTmpDirNotSet(t *testing.T) {
// compute env // compute env
env := make(map[string]string, 0) env := make(map[string]string, 0)
err := setTempDirEnvVars(env, b) err := setTempDirEnvVars(context.Background(), env, b)
require.NoError(t, err) require.NoError(t, err)
// Assert that we don't pass through TMPDIR. // Assert that we don't pass through TMPDIR.
@ -124,7 +124,7 @@ func TestSetTempDirEnvVarsForWindowWithAllTmpDirEnvVarsSet(t *testing.T) {
// compute env // compute env
env := make(map[string]string, 0) env := make(map[string]string, 0)
err := setTempDirEnvVars(env, b) err := setTempDirEnvVars(context.Background(), env, b)
require.NoError(t, err) require.NoError(t, err)
// assert that we pass through the highest priority env var value // assert that we pass through the highest priority env var value
@ -154,7 +154,7 @@ func TestSetTempDirEnvVarsForWindowWithUserProfileAndTempSet(t *testing.T) {
// compute env // compute env
env := make(map[string]string, 0) env := make(map[string]string, 0)
err := setTempDirEnvVars(env, b) err := setTempDirEnvVars(context.Background(), env, b)
require.NoError(t, err) require.NoError(t, err)
// assert that we pass through the highest priority env var value // assert that we pass through the highest priority env var value
@ -184,7 +184,7 @@ func TestSetTempDirEnvVarsForWindowWithUserProfileSet(t *testing.T) {
// compute env // compute env
env := make(map[string]string, 0) env := make(map[string]string, 0)
err := setTempDirEnvVars(env, b) err := setTempDirEnvVars(context.Background(), env, b)
require.NoError(t, err) require.NoError(t, err)
// assert that we pass through the user profile // assert that we pass through the user profile
@ -214,11 +214,11 @@ func TestSetTempDirEnvVarsForWindowsWithoutAnyTempDirEnvVarsSet(t *testing.T) {
// compute env // compute env
env := make(map[string]string, 0) env := make(map[string]string, 0)
err := setTempDirEnvVars(env, b) err := setTempDirEnvVars(context.Background(), env, b)
require.NoError(t, err) require.NoError(t, err)
// assert TMP is set to b.CacheDir("tmp") // assert TMP is set to b.CacheDir("tmp")
tmpDir, err := b.CacheDir("tmp") tmpDir, err := b.CacheDir(context.Background(), "tmp")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, map[string]string{ assert.Equal(t, map[string]string{
"TMP": tmpDir, "TMP": tmpDir,
@ -248,7 +248,7 @@ func TestSetProxyEnvVars(t *testing.T) {
// No proxy env vars set. // No proxy env vars set.
clearEnv() clearEnv()
env := make(map[string]string, 0) env := make(map[string]string, 0)
err := setProxyEnvVars(env, b) err := setProxyEnvVars(context.Background(), env, b)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, env, 0) assert.Len(t, env, 0)
@ -258,7 +258,7 @@ func TestSetProxyEnvVars(t *testing.T) {
t.Setenv("https_proxy", "foo") t.Setenv("https_proxy", "foo")
t.Setenv("no_proxy", "foo") t.Setenv("no_proxy", "foo")
env = make(map[string]string, 0) env = make(map[string]string, 0)
err = setProxyEnvVars(env, b) err = setProxyEnvVars(context.Background(), env, b)
require.NoError(t, err) require.NoError(t, err)
assert.ElementsMatch(t, []string{"HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY"}, maps.Keys(env)) assert.ElementsMatch(t, []string{"HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY"}, maps.Keys(env))
@ -268,7 +268,7 @@ func TestSetProxyEnvVars(t *testing.T) {
t.Setenv("HTTPS_PROXY", "foo") t.Setenv("HTTPS_PROXY", "foo")
t.Setenv("NO_PROXY", "foo") t.Setenv("NO_PROXY", "foo")
env = make(map[string]string, 0) env = make(map[string]string, 0)
err = setProxyEnvVars(env, b) err = setProxyEnvVars(context.Background(), env, b)
require.NoError(t, err) require.NoError(t, err)
assert.ElementsMatch(t, []string{"HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY"}, maps.Keys(env)) assert.ElementsMatch(t, []string{"HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY"}, maps.Keys(env))
} }
@ -280,7 +280,7 @@ func TestInheritEnvVars(t *testing.T) {
t.Setenv("PATH", "/foo:/bar") t.Setenv("PATH", "/foo:/bar")
t.Setenv("TF_CLI_CONFIG_FILE", "/tmp/config.tfrc") t.Setenv("TF_CLI_CONFIG_FILE", "/tmp/config.tfrc")
err := inheritEnvVars(env) err := inheritEnvVars(context.Background(), env)
require.NoError(t, err) require.NoError(t, err)

View File

@ -40,7 +40,7 @@ func (p *plan) Apply(ctx context.Context, b *bundle.Bundle) error {
} }
// Persist computed plan // Persist computed plan
tfDir, err := Dir(b) tfDir, err := Dir(ctx, b)
if err != nil { if err != nil {
return err return err
} }

View File

@ -25,7 +25,7 @@ func (l *statePull) Apply(ctx context.Context, b *bundle.Bundle) error {
return err return err
} }
dir, err := Dir(b) dir, err := Dir(ctx, b)
if err != nil { if err != nil {
return err return err
} }

View File

@ -22,7 +22,7 @@ func (l *statePush) Apply(ctx context.Context, b *bundle.Bundle) error {
return err return err
} }
dir, err := Dir(b) dir, err := Dir(ctx, b)
if err != nil { if err != nil {
return err return err
} }

View File

@ -16,7 +16,7 @@ func (w *write) Name() string {
} }
func (w *write) Apply(ctx context.Context, b *bundle.Bundle) error { func (w *write) Apply(ctx context.Context, b *bundle.Bundle) error {
dir, err := Dir(b) dir, err := Dir(ctx, b)
if err != nil { if err != nil {
return err return err
} }

18
bundle/env/env.go vendored Normal file
View File

@ -0,0 +1,18 @@
package env
import (
"context"
envlib "github.com/databricks/cli/libs/env"
)
// Return the value of the first environment variable that is set.
func get(ctx context.Context, variables []string) (string, bool) {
for _, v := range variables {
value, ok := envlib.Lookup(ctx, v)
if ok {
return value, true
}
}
return "", false
}

44
bundle/env/env_test.go vendored Normal file
View File

@ -0,0 +1,44 @@
package env
import (
"context"
"testing"
"github.com/databricks/cli/internal/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetWithRealEnvSingleVariable(t *testing.T) {
testutil.CleanupEnvironment(t)
t.Setenv("v1", "foo")
v, ok := get(context.Background(), []string{"v1"})
require.True(t, ok)
assert.Equal(t, "foo", v)
// Not set.
v, ok = get(context.Background(), []string{"v2"})
require.False(t, ok)
assert.Equal(t, "", v)
}
func TestGetWithRealEnvMultipleVariables(t *testing.T) {
testutil.CleanupEnvironment(t)
t.Setenv("v1", "foo")
for _, vars := range [][]string{
{"v1", "v2", "v3"},
{"v2", "v3", "v1"},
{"v3", "v1", "v2"},
} {
v, ok := get(context.Background(), vars)
require.True(t, ok)
assert.Equal(t, "foo", v)
}
// Not set.
v, ok := get(context.Background(), []string{"v2", "v3", "v4"})
require.False(t, ok)
assert.Equal(t, "", v)
}

14
bundle/env/includes.go vendored Normal file
View File

@ -0,0 +1,14 @@
package env
import "context"
// IncludesVariable names the environment variable that holds additional configuration paths to include
// during bundle configuration loading. Also see `bundle/config/mutator/process_root_includes.go`.
const IncludesVariable = "DATABRICKS_BUNDLE_INCLUDES"
// Includes returns the bundle Includes environment variable.
func Includes(ctx context.Context) (string, bool) {
return get(ctx, []string{
IncludesVariable,
})
}

28
bundle/env/includes_test.go vendored Normal file
View File

@ -0,0 +1,28 @@
package env
import (
"context"
"testing"
"github.com/databricks/cli/internal/testutil"
"github.com/stretchr/testify/assert"
)
func TestIncludes(t *testing.T) {
ctx := context.Background()
testutil.CleanupEnvironment(t)
t.Run("set", func(t *testing.T) {
t.Setenv("DATABRICKS_BUNDLE_INCLUDES", "foo")
includes, ok := Includes(ctx)
assert.True(t, ok)
assert.Equal(t, "foo", includes)
})
t.Run("not set", func(t *testing.T) {
includes, ok := Includes(ctx)
assert.False(t, ok)
assert.Equal(t, "", includes)
})
}

16
bundle/env/root.go vendored Normal file
View File

@ -0,0 +1,16 @@
package env
import "context"
// RootVariable names the environment variable that holds the bundle root path.
const RootVariable = "DATABRICKS_BUNDLE_ROOT"
// Root returns the bundle root environment variable.
func Root(ctx context.Context) (string, bool) {
return get(ctx, []string{
RootVariable,
// Primary variable name for the bundle root until v0.204.0.
"BUNDLE_ROOT",
})
}

43
bundle/env/root_test.go vendored Normal file
View File

@ -0,0 +1,43 @@
package env
import (
"context"
"testing"
"github.com/databricks/cli/internal/testutil"
"github.com/stretchr/testify/assert"
)
func TestRoot(t *testing.T) {
ctx := context.Background()
testutil.CleanupEnvironment(t)
t.Run("first", func(t *testing.T) {
t.Setenv("DATABRICKS_BUNDLE_ROOT", "foo")
root, ok := Root(ctx)
assert.True(t, ok)
assert.Equal(t, "foo", root)
})
t.Run("second", func(t *testing.T) {
t.Setenv("BUNDLE_ROOT", "foo")
root, ok := Root(ctx)
assert.True(t, ok)
assert.Equal(t, "foo", root)
})
t.Run("both set", func(t *testing.T) {
t.Setenv("DATABRICKS_BUNDLE_ROOT", "first")
t.Setenv("BUNDLE_ROOT", "second")
root, ok := Root(ctx)
assert.True(t, ok)
assert.Equal(t, "first", root)
})
t.Run("not set", func(t *testing.T) {
root, ok := Root(ctx)
assert.False(t, ok)
assert.Equal(t, "", root)
})
}

17
bundle/env/target.go vendored Normal file
View File

@ -0,0 +1,17 @@
package env
import "context"
// TargetVariable names the environment variable that holds the bundle target to use.
const TargetVariable = "DATABRICKS_BUNDLE_TARGET"
// Target returns the bundle target environment variable.
func Target(ctx context.Context) (string, bool) {
return get(ctx, []string{
TargetVariable,
// Primary variable name for the bundle target until v0.203.2.
// See https://github.com/databricks/cli/pull/670.
"DATABRICKS_BUNDLE_ENV",
})
}

43
bundle/env/target_test.go vendored Normal file
View File

@ -0,0 +1,43 @@
package env
import (
"context"
"testing"
"github.com/databricks/cli/internal/testutil"
"github.com/stretchr/testify/assert"
)
func TestTarget(t *testing.T) {
ctx := context.Background()
testutil.CleanupEnvironment(t)
t.Run("first", func(t *testing.T) {
t.Setenv("DATABRICKS_BUNDLE_TARGET", "foo")
target, ok := Target(ctx)
assert.True(t, ok)
assert.Equal(t, "foo", target)
})
t.Run("second", func(t *testing.T) {
t.Setenv("DATABRICKS_BUNDLE_ENV", "foo")
target, ok := Target(ctx)
assert.True(t, ok)
assert.Equal(t, "foo", target)
})
t.Run("both set", func(t *testing.T) {
t.Setenv("DATABRICKS_BUNDLE_TARGET", "first")
t.Setenv("DATABRICKS_BUNDLE_ENV", "second")
target, ok := Target(ctx)
assert.True(t, ok)
assert.Equal(t, "first", target)
})
t.Run("not set", func(t *testing.T) {
target, ok := Target(ctx)
assert.False(t, ok)
assert.Equal(t, "", target)
})
}

13
bundle/env/temp_dir.go vendored Normal file
View File

@ -0,0 +1,13 @@
package env
import "context"
// TempDirVariable names the environment variable that holds the temporary directory to use.
const TempDirVariable = "DATABRICKS_BUNDLE_TMP"
// TempDir returns the temporary directory to use.
func TempDir(ctx context.Context) (string, bool) {
return get(ctx, []string{
TempDirVariable,
})
}

28
bundle/env/temp_dir_test.go vendored Normal file
View File

@ -0,0 +1,28 @@
package env
import (
"context"
"testing"
"github.com/databricks/cli/internal/testutil"
"github.com/stretchr/testify/assert"
)
func TestTempDir(t *testing.T) {
ctx := context.Background()
testutil.CleanupEnvironment(t)
t.Run("set", func(t *testing.T) {
t.Setenv("DATABRICKS_BUNDLE_TMP", "foo")
tempDir, ok := TempDir(ctx)
assert.True(t, ok)
assert.Equal(t, "foo", tempDir)
})
t.Run("not set", func(t *testing.T) {
tempDir, ok := TempDir(ctx)
assert.False(t, ok)
assert.Equal(t, "", tempDir)
})
}

View File

@ -1,21 +1,21 @@
package bundle package bundle
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"github.com/databricks/cli/bundle/config" "github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/env"
"github.com/databricks/cli/folders" "github.com/databricks/cli/folders"
) )
const envBundleRoot = "BUNDLE_ROOT" // getRootEnv returns the value of the bundle root environment variable
// getRootEnv returns the value of the `BUNDLE_ROOT` environment variable
// if it set and is a directory. If the environment variable is set but // if it set and is a directory. If the environment variable is set but
// is not a directory, it returns an error. If the environment variable is // is not a directory, it returns an error. If the environment variable is
// not set, it returns an empty string. // not set, it returns an empty string.
func getRootEnv() (string, error) { func getRootEnv(ctx context.Context) (string, error) {
path, ok := os.LookupEnv(envBundleRoot) path, ok := env.Root(ctx)
if !ok { if !ok {
return "", nil return "", nil
} }
@ -24,7 +24,7 @@ func getRootEnv() (string, error) {
err = fmt.Errorf("not a directory") err = fmt.Errorf("not a directory")
} }
if err != nil { if err != nil {
return "", fmt.Errorf(`invalid bundle root %s="%s": %w`, envBundleRoot, path, err) return "", fmt.Errorf(`invalid bundle root %s="%s": %w`, env.RootVariable, path, err)
} }
return path, nil return path, nil
} }
@ -48,8 +48,8 @@ func getRootWithTraversal() (string, error) {
} }
// mustGetRoot returns a bundle root or an error if one cannot be found. // mustGetRoot returns a bundle root or an error if one cannot be found.
func mustGetRoot() (string, error) { func mustGetRoot(ctx context.Context) (string, error) {
path, err := getRootEnv() path, err := getRootEnv(ctx)
if path != "" || err != nil { if path != "" || err != nil {
return path, err return path, err
} }
@ -57,9 +57,9 @@ func mustGetRoot() (string, error) {
} }
// tryGetRoot returns a bundle root or an empty string if one cannot be found. // tryGetRoot returns a bundle root or an empty string if one cannot be found.
func tryGetRoot() (string, error) { func tryGetRoot(ctx context.Context) (string, error) {
// Note: an invalid value in the environment variable is still an error. // Note: an invalid value in the environment variable is still an error.
path, err := getRootEnv() path, err := getRootEnv(ctx)
if path != "" || err != nil { if path != "" || err != nil {
return path, err return path, err
} }

View File

@ -7,6 +7,7 @@ import (
"testing" "testing"
"github.com/databricks/cli/bundle/config" "github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/env"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -32,49 +33,55 @@ func chdir(t *testing.T, dir string) string {
} }
func TestRootFromEnv(t *testing.T) { func TestRootFromEnv(t *testing.T) {
ctx := context.Background()
dir := t.TempDir() dir := t.TempDir()
t.Setenv(envBundleRoot, dir) t.Setenv(env.RootVariable, dir)
// It should pull the root from the environment variable. // It should pull the root from the environment variable.
root, err := mustGetRoot() root, err := mustGetRoot(ctx)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, root, dir) require.Equal(t, root, dir)
} }
func TestRootFromEnvDoesntExist(t *testing.T) { func TestRootFromEnvDoesntExist(t *testing.T) {
ctx := context.Background()
dir := t.TempDir() dir := t.TempDir()
t.Setenv(envBundleRoot, filepath.Join(dir, "doesntexist")) t.Setenv(env.RootVariable, filepath.Join(dir, "doesntexist"))
// It should pull the root from the environment variable. // It should pull the root from the environment variable.
_, err := mustGetRoot() _, err := mustGetRoot(ctx)
require.Errorf(t, err, "invalid bundle root") require.Errorf(t, err, "invalid bundle root")
} }
func TestRootFromEnvIsFile(t *testing.T) { func TestRootFromEnvIsFile(t *testing.T) {
ctx := context.Background()
dir := t.TempDir() dir := t.TempDir()
f, err := os.Create(filepath.Join(dir, "invalid")) f, err := os.Create(filepath.Join(dir, "invalid"))
require.NoError(t, err) require.NoError(t, err)
f.Close() f.Close()
t.Setenv(envBundleRoot, f.Name()) t.Setenv(env.RootVariable, f.Name())
// It should pull the root from the environment variable. // It should pull the root from the environment variable.
_, err = mustGetRoot() _, err = mustGetRoot(ctx)
require.Errorf(t, err, "invalid bundle root") require.Errorf(t, err, "invalid bundle root")
} }
func TestRootIfEnvIsEmpty(t *testing.T) { func TestRootIfEnvIsEmpty(t *testing.T) {
ctx := context.Background()
dir := "" dir := ""
t.Setenv(envBundleRoot, dir) t.Setenv(env.RootVariable, dir)
// It should pull the root from the environment variable. // It should pull the root from the environment variable.
_, err := mustGetRoot() _, err := mustGetRoot(ctx)
require.Errorf(t, err, "invalid bundle root") require.Errorf(t, err, "invalid bundle root")
} }
func TestRootLookup(t *testing.T) { func TestRootLookup(t *testing.T) {
ctx := context.Background()
// Have to set then unset to allow the testing package to revert it to its original value. // Have to set then unset to allow the testing package to revert it to its original value.
t.Setenv(envBundleRoot, "") t.Setenv(env.RootVariable, "")
os.Unsetenv(envBundleRoot) os.Unsetenv(env.RootVariable)
chdir(t, t.TempDir()) chdir(t, t.TempDir())
@ -89,27 +96,30 @@ func TestRootLookup(t *testing.T) {
// It should find the project root from $PWD. // It should find the project root from $PWD.
wd := chdir(t, "./a/b/c") wd := chdir(t, "./a/b/c")
root, err := mustGetRoot() root, err := mustGetRoot(ctx)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, wd, root) require.Equal(t, wd, root)
} }
func TestRootLookupError(t *testing.T) { func TestRootLookupError(t *testing.T) {
ctx := context.Background()
// Have to set then unset to allow the testing package to revert it to its original value. // Have to set then unset to allow the testing package to revert it to its original value.
t.Setenv(envBundleRoot, "") t.Setenv(env.RootVariable, "")
os.Unsetenv(envBundleRoot) os.Unsetenv(env.RootVariable)
// It can't find a project root from a temporary directory. // It can't find a project root from a temporary directory.
_ = chdir(t, t.TempDir()) _ = chdir(t, t.TempDir())
_, err := mustGetRoot() _, err := mustGetRoot(ctx)
require.ErrorContains(t, err, "unable to locate bundle root") require.ErrorContains(t, err, "unable to locate bundle root")
} }
func TestLoadYamlWhenIncludesEnvPresent(t *testing.T) { func TestLoadYamlWhenIncludesEnvPresent(t *testing.T) {
ctx := context.Background()
chdir(t, filepath.Join(".", "tests", "basic")) chdir(t, filepath.Join(".", "tests", "basic"))
t.Setenv(ExtraIncludePathsKey, "test") t.Setenv(env.IncludesVariable, "test")
bundle, err := MustLoad(context.Background()) bundle, err := MustLoad(ctx)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "basic", bundle.Config.Bundle.Name) assert.Equal(t, "basic", bundle.Config.Bundle.Name)
@ -119,30 +129,33 @@ func TestLoadYamlWhenIncludesEnvPresent(t *testing.T) {
} }
func TestLoadDefautlBundleWhenNoYamlAndRootAndIncludesEnvPresent(t *testing.T) { func TestLoadDefautlBundleWhenNoYamlAndRootAndIncludesEnvPresent(t *testing.T) {
ctx := context.Background()
dir := t.TempDir() dir := t.TempDir()
chdir(t, dir) chdir(t, dir)
t.Setenv(envBundleRoot, dir) t.Setenv(env.RootVariable, dir)
t.Setenv(ExtraIncludePathsKey, "test") t.Setenv(env.IncludesVariable, "test")
bundle, err := MustLoad(context.Background()) bundle, err := MustLoad(ctx)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, dir, bundle.Config.Path) assert.Equal(t, dir, bundle.Config.Path)
} }
func TestErrorIfNoYamlNoRootEnvAndIncludesEnvPresent(t *testing.T) { func TestErrorIfNoYamlNoRootEnvAndIncludesEnvPresent(t *testing.T) {
ctx := context.Background()
dir := t.TempDir() dir := t.TempDir()
chdir(t, dir) chdir(t, dir)
t.Setenv(ExtraIncludePathsKey, "test") t.Setenv(env.IncludesVariable, "test")
_, err := MustLoad(context.Background()) _, err := MustLoad(ctx)
assert.Error(t, err) assert.Error(t, err)
} }
func TestErrorIfNoYamlNoIncludesEnvAndRootEnvPresent(t *testing.T) { func TestErrorIfNoYamlNoIncludesEnvAndRootEnvPresent(t *testing.T) {
ctx := context.Background()
dir := t.TempDir() dir := t.TempDir()
chdir(t, dir) chdir(t, dir)
t.Setenv(envBundleRoot, dir) t.Setenv(env.RootVariable, dir)
_, err := MustLoad(context.Background()) _, err := MustLoad(ctx)
assert.Error(t, err) assert.Error(t, err)
} }

View File

@ -18,12 +18,12 @@ type syncFlags struct {
} }
func (f *syncFlags) syncOptionsFromBundle(cmd *cobra.Command, b *bundle.Bundle) (*sync.SyncOptions, error) { func (f *syncFlags) syncOptionsFromBundle(cmd *cobra.Command, b *bundle.Bundle) (*sync.SyncOptions, error) {
cacheDir, err := b.CacheDir() cacheDir, err := b.CacheDir(cmd.Context())
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot get bundle cache directory: %w", err) return nil, fmt.Errorf("cannot get bundle cache directory: %w", err)
} }
includes, err := b.GetSyncIncludePatterns() includes, err := b.GetSyncIncludePatterns(cmd.Context())
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot get list of sync includes: %w", err) return nil, fmt.Errorf("cannot get list of sync includes: %w", err)
} }

View File

@ -1,6 +1,7 @@
package cmd package cmd
import ( import (
"context"
"strings" "strings"
"github.com/databricks/cli/cmd/account" "github.com/databricks/cli/cmd/account"
@ -21,8 +22,8 @@ const (
permissionsGroup = "permissions" permissionsGroup = "permissions"
) )
func New() *cobra.Command { func New(ctx context.Context) *cobra.Command {
cli := root.New() cli := root.New(ctx)
// Add account subcommand. // Add account subcommand.
cli.AddCommand(account.New()) cli.AddCommand(account.New())

View File

@ -54,7 +54,7 @@ func TestDefaultConfigureNoInteractive(t *testing.T) {
}) })
os.Stdin = inp os.Stdin = inp
cmd := cmd.New() cmd := cmd.New(ctx)
cmd.SetArgs([]string{"configure", "--token", "--host", "https://host"}) cmd.SetArgs([]string{"configure", "--token", "--host", "https://host"})
err := cmd.ExecuteContext(ctx) err := cmd.ExecuteContext(ctx)
@ -87,7 +87,7 @@ func TestConfigFileFromEnvNoInteractive(t *testing.T) {
t.Cleanup(func() { os.Stdin = oldStdin }) t.Cleanup(func() { os.Stdin = oldStdin })
os.Stdin = inp os.Stdin = inp
cmd := cmd.New() cmd := cmd.New(ctx)
cmd.SetArgs([]string{"configure", "--token", "--host", "https://host"}) cmd.SetArgs([]string{"configure", "--token", "--host", "https://host"})
err := cmd.ExecuteContext(ctx) err := cmd.ExecuteContext(ctx)
@ -116,7 +116,7 @@ func TestCustomProfileConfigureNoInteractive(t *testing.T) {
t.Cleanup(func() { os.Stdin = oldStdin }) t.Cleanup(func() { os.Stdin = oldStdin })
os.Stdin = inp os.Stdin = inp
cmd := cmd.New() cmd := cmd.New(ctx)
cmd.SetArgs([]string{"configure", "--token", "--host", "https://host", "--profile", "CUSTOM"}) cmd.SetArgs([]string{"configure", "--token", "--host", "https://host", "--profile", "CUSTOM"})
err := cmd.ExecuteContext(ctx) err := cmd.ExecuteContext(ctx)

View File

@ -2,17 +2,15 @@ package root
import ( import (
"context" "context"
"os"
"github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config/mutator" "github.com/databricks/cli/bundle/config/mutator"
"github.com/databricks/cli/bundle/env"
envlib "github.com/databricks/cli/libs/env"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
) )
const envName = "DATABRICKS_BUNDLE_ENV"
const targetName = "DATABRICKS_BUNDLE_TARGET"
// getTarget returns the name of the target to operate in. // getTarget returns the name of the target to operate in.
func getTarget(cmd *cobra.Command) (value string) { func getTarget(cmd *cobra.Command) (value string) {
// The command line flag takes precedence. // The command line flag takes precedence.
@ -33,13 +31,7 @@ func getTarget(cmd *cobra.Command) (value string) {
} }
// If it's not set, use the environment variable. // If it's not set, use the environment variable.
target := os.Getenv(targetName) target, _ := env.Target(cmd.Context())
// If target env is not set with a new variable, try to check for old variable name
// TODO: remove when environments section is not supported anymore
if target == "" {
target = os.Getenv(envName)
}
return target return target
} }
@ -54,7 +46,7 @@ func getProfile(cmd *cobra.Command) (value string) {
} }
// If it's not set, use the environment variable. // If it's not set, use the environment variable.
return os.Getenv("DATABRICKS_CONFIG_PROFILE") return envlib.Get(cmd.Context(), "DATABRICKS_CONFIG_PROFILE")
} }
// loadBundle loads the bundle configuration and applies default mutators. // loadBundle loads the bundle configuration and applies default mutators.

View File

@ -1,9 +1,8 @@
package root package root
import ( import (
"os"
"github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/env"
"github.com/databricks/cli/libs/flags" "github.com/databricks/cli/libs/flags"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -21,7 +20,7 @@ func initOutputFlag(cmd *cobra.Command) *outputFlag {
// Configure defaults from environment, if applicable. // Configure defaults from environment, if applicable.
// If the provided value is invalid it is ignored. // If the provided value is invalid it is ignored.
if v, ok := os.LookupEnv(envOutputFormat); ok { if v, ok := env.Lookup(cmd.Context(), envOutputFormat); ok {
f.output.Set(v) f.output.Set(v)
} }

View File

@ -5,9 +5,9 @@ import (
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"os"
"github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/env"
"github.com/databricks/cli/libs/flags" "github.com/databricks/cli/libs/flags"
"github.com/databricks/cli/libs/log" "github.com/databricks/cli/libs/log"
"github.com/fatih/color" "github.com/fatih/color"
@ -126,13 +126,13 @@ func initLogFlags(cmd *cobra.Command) *logFlags {
// Configure defaults from environment, if applicable. // Configure defaults from environment, if applicable.
// If the provided value is invalid it is ignored. // If the provided value is invalid it is ignored.
if v, ok := os.LookupEnv(envLogFile); ok { if v, ok := env.Lookup(cmd.Context(), envLogFile); ok {
f.file.Set(v) f.file.Set(v)
} }
if v, ok := os.LookupEnv(envLogLevel); ok { if v, ok := env.Lookup(cmd.Context(), envLogLevel); ok {
f.level.Set(v) f.level.Set(v)
} }
if v, ok := os.LookupEnv(envLogFormat); ok { if v, ok := env.Lookup(cmd.Context(), envLogFormat); ok {
f.output.Set(v) f.output.Set(v)
} }

View File

@ -6,6 +6,7 @@ import (
"os" "os"
"github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/env"
"github.com/databricks/cli/libs/flags" "github.com/databricks/cli/libs/flags"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/term" "golang.org/x/term"
@ -51,7 +52,7 @@ func initProgressLoggerFlag(cmd *cobra.Command, logFlags *logFlags) *progressLog
// Configure defaults from environment, if applicable. // Configure defaults from environment, if applicable.
// If the provided value is invalid it is ignored. // If the provided value is invalid it is ignored.
if v, ok := os.LookupEnv(envProgressFormat); ok { if v, ok := env.Lookup(cmd.Context(), envProgressFormat); ok {
f.Set(v) f.Set(v)
} }

View File

@ -14,7 +14,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
func New() *cobra.Command { func New(ctx context.Context) *cobra.Command {
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "databricks", Use: "databricks",
Short: "Databricks CLI", Short: "Databricks CLI",
@ -30,6 +30,10 @@ func New() *cobra.Command {
SilenceErrors: true, SilenceErrors: true,
} }
// Pass the context along through the command during initialization.
// It will be overwritten when the command is executed.
cmd.SetContext(ctx)
// Initialize flags // Initialize flags
logFlags := initLogFlags(cmd) logFlags := initLogFlags(cmd)
progressLoggerFlag := initProgressLoggerFlag(cmd, logFlags) progressLoggerFlag := initProgressLoggerFlag(cmd, logFlags)

View File

@ -2,8 +2,8 @@ package root
import ( import (
"context" "context"
"os"
"github.com/databricks/cli/libs/env"
"github.com/databricks/databricks-sdk-go/useragent" "github.com/databricks/databricks-sdk-go/useragent"
) )
@ -16,7 +16,7 @@ const upstreamKey = "upstream"
const upstreamVersionKey = "upstream-version" const upstreamVersionKey = "upstream-version"
func withUpstreamInUserAgent(ctx context.Context) context.Context { func withUpstreamInUserAgent(ctx context.Context) context.Context {
value := os.Getenv(upstreamEnvVar) value := env.Get(ctx, upstreamEnvVar)
if value == "" { if value == "" {
return ctx return ctx
} }
@ -24,7 +24,7 @@ func withUpstreamInUserAgent(ctx context.Context) context.Context {
ctx = useragent.InContext(ctx, upstreamKey, value) ctx = useragent.InContext(ctx, upstreamKey, value)
// Include upstream version as well, if set. // Include upstream version as well, if set.
value = os.Getenv(upstreamVersionEnvVar) value = env.Get(ctx, upstreamVersionEnvVar)
if value == "" { if value == "" {
return ctx return ctx
} }

View File

@ -30,12 +30,12 @@ func (f *syncFlags) syncOptionsFromBundle(cmd *cobra.Command, args []string, b *
return nil, fmt.Errorf("SRC and DST are not configurable in the context of a bundle") return nil, fmt.Errorf("SRC and DST are not configurable in the context of a bundle")
} }
cacheDir, err := b.CacheDir() cacheDir, err := b.CacheDir(cmd.Context())
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot get bundle cache directory: %w", err) return nil, fmt.Errorf("cannot get bundle cache directory: %w", err)
} }
includes, err := b.GetSyncIncludePatterns() includes, err := b.GetSyncIncludePatterns(cmd.Context())
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot get list of sync includes: %w", err) return nil, fmt.Errorf("cannot get list of sync includes: %w", err)
} }

View File

@ -118,7 +118,7 @@ func (t *cobraTestRunner) RunBackground() {
var stdoutW, stderrW io.WriteCloser var stdoutW, stderrW io.WriteCloser
stdoutR, stdoutW = io.Pipe() stdoutR, stdoutW = io.Pipe()
stderrR, stderrW = io.Pipe() stderrR, stderrW = io.Pipe()
root := cmd.New() root := cmd.New(context.Background())
root.SetOut(stdoutW) root.SetOut(stdoutW)
root.SetErr(stderrW) root.SetErr(stderrW)
root.SetArgs(t.args) root.SetArgs(t.args)

33
internal/testutil/env.go Normal file
View File

@ -0,0 +1,33 @@
package testutil
import (
"os"
"strings"
"testing"
)
// CleanupEnvironment sets up a pristine environment containing only $PATH and $HOME.
// The original environment is restored upon test completion.
// Note: use of this function is incompatible with parallel execution.
func CleanupEnvironment(t *testing.T) {
// Restore environment when test finishes.
environ := os.Environ()
t.Cleanup(func() {
// Restore original environment.
for _, kv := range environ {
kvs := strings.SplitN(kv, "=", 2)
os.Setenv(kvs[0], kvs[1])
}
})
path := os.Getenv("PATH")
pwd := os.Getenv("PWD")
os.Clearenv()
// We use t.Setenv instead of os.Setenv because the former actively
// prevents a test being run with t.Parallel. Modifying the environment
// within a test is not compatible with running tests in parallel
// because of isolation; the environment is scoped to the process.
t.Setenv("PATH", path)
t.Setenv("HOME", pwd)
}

63
libs/env/context.go vendored Normal file
View File

@ -0,0 +1,63 @@
package env
import (
"context"
"os"
)
var envContextKey int
func copyMap(m map[string]string) map[string]string {
out := make(map[string]string, len(m))
for k, v := range m {
out[k] = v
}
return out
}
func getMap(ctx context.Context) map[string]string {
if ctx == nil {
return nil
}
m, ok := ctx.Value(&envContextKey).(map[string]string)
if !ok {
return nil
}
return m
}
func setMap(ctx context.Context, m map[string]string) context.Context {
return context.WithValue(ctx, &envContextKey, m)
}
// Lookup key in the context or the the environment.
// Context has precedence.
func Lookup(ctx context.Context, key string) (string, bool) {
m := getMap(ctx)
// Return if the key is set in the context.
v, ok := m[key]
if ok {
return v, true
}
// Fall back to the environment.
return os.LookupEnv(key)
}
// Get key from the context or the environment.
// Context has precedence.
func Get(ctx context.Context, key string) string {
v, _ := Lookup(ctx, key)
return v
}
// Set key on the context.
//
// Note: this does NOT mutate the processes' actual environment variables.
// It is only visible to other code that uses this package.
func Set(ctx context.Context, key, value string) context.Context {
m := copyMap(getMap(ctx))
m[key] = value
return setMap(ctx, m)
}

41
libs/env/context_test.go vendored Normal file
View File

@ -0,0 +1,41 @@
package env
import (
"context"
"testing"
"github.com/databricks/cli/internal/testutil"
"github.com/stretchr/testify/assert"
)
func TestContext(t *testing.T) {
testutil.CleanupEnvironment(t)
t.Setenv("FOO", "bar")
ctx0 := context.Background()
// Get
assert.Equal(t, "bar", Get(ctx0, "FOO"))
assert.Equal(t, "", Get(ctx0, "dontexist"))
// Lookup
v, ok := Lookup(ctx0, "FOO")
assert.True(t, ok)
assert.Equal(t, "bar", v)
v, ok = Lookup(ctx0, "dontexist")
assert.False(t, ok)
assert.Equal(t, "", v)
// Set and get new context.
// Verify that the previous context remains unchanged.
ctx1 := Set(ctx0, "FOO", "baz")
assert.Equal(t, "baz", Get(ctx1, "FOO"))
assert.Equal(t, "bar", Get(ctx0, "FOO"))
// Set and get new context.
// Verify that the previous contexts remains unchanged.
ctx2 := Set(ctx1, "FOO", "qux")
assert.Equal(t, "qux", Get(ctx2, "FOO"))
assert.Equal(t, "baz", Get(ctx1, "FOO"))
assert.Equal(t, "bar", Get(ctx0, "FOO"))
}

7
libs/env/pkg.go vendored Normal file
View File

@ -0,0 +1,7 @@
package env
// The env package provides functions for working with environment variables
// and allowing for overrides via the context.Context. This is useful for
// testing where tainting a processes' environment is at odds with parallelism.
// Use of a context.Context to store variable overrides means tests can be
// parallelized without worrying about environment variable interference.

View File

@ -1,10 +1,12 @@
package main package main
import ( import (
"context"
"github.com/databricks/cli/cmd" "github.com/databricks/cli/cmd"
"github.com/databricks/cli/cmd/root" "github.com/databricks/cli/cmd/root"
) )
func main() { func main() {
root.Execute(cmd.New()) root.Execute(cmd.New(context.Background()))
} }

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"testing" "testing"
"github.com/databricks/cli/cmd" "github.com/databricks/cli/cmd"
@ -15,7 +16,7 @@ func TestCommandsDontUseUnderscoreInName(t *testing.T) {
// This test lives in the main package because this is where // This test lives in the main package because this is where
// all commands are imported. // all commands are imported.
// //
queue := []*cobra.Command{cmd.New()} queue := []*cobra.Command{cmd.New(context.Background())}
for len(queue) > 0 { for len(queue) > 0 {
cmd := queue[0] cmd := queue[0]
assert.NotContains(t, cmd.Name(), "_") assert.NotContains(t, cmd.Name(), "_")