Added `env.UserHomeDir(ctx)` for parallel-friendly tests (#955)

## Changes
`os.Getenv(..)` is not friendly with `libs/env`. This PR makes the
relevant changes to places where we need to read user home directory.

## Tests
Mainly done in https://github.com/databricks/cli/pull/914
This commit is contained in:
Serge Smertin 2023-11-08 15:50:20 +01:00 committed by GitHub
parent 7847388f95
commit e68a88e14d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 190 additions and 56 deletions

View File

@ -1,6 +1,7 @@
package auth package auth
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -68,8 +69,8 @@ func resolveSection(cfg *config.Config, iniFile *config.File) (*ini.Section, err
return candidates[0], nil return candidates[0], nil
} }
func loadFromDatabricksCfg(cfg *config.Config) error { func loadFromDatabricksCfg(ctx context.Context, cfg *config.Config) error {
iniFile, err := databrickscfg.Get() iniFile, err := databrickscfg.Get(ctx)
if errors.Is(err, fs.ErrNotExist) { if errors.Is(err, fs.ErrNotExist) {
// it's fine not to have ~/.databrickscfg // it's fine not to have ~/.databrickscfg
return nil return nil
@ -110,7 +111,7 @@ func newEnvCommand() *cobra.Command {
cfg.Profile = profile cfg.Profile = profile
} else if cfg.Host == "" { } else if cfg.Host == "" {
cfg.Profile = "DEFAULT" cfg.Profile = "DEFAULT"
} else if err := loadFromDatabricksCfg(cfg); err != nil { } else if err := loadFromDatabricksCfg(cmd.Context(), cfg); err != nil {
return err return err
} }
// Go SDK is lazy loaded because of Terraform semantics, // Go SDK is lazy loaded because of Terraform semantics,

View File

@ -128,7 +128,7 @@ func newLoginCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
func setHost(ctx context.Context, profileName string, persistentAuth *auth.PersistentAuth, args []string) error { func setHost(ctx context.Context, profileName string, persistentAuth *auth.PersistentAuth, args []string) error {
// If the chosen profile has a hostname and the user hasn't specified a host, infer the host from the profile. // If the chosen profile has a hostname and the user hasn't specified a host, infer the host from the profile.
_, profiles, err := databrickscfg.LoadProfiles(func(p databrickscfg.Profile) bool { _, profiles, err := databrickscfg.LoadProfiles(ctx, func(p databrickscfg.Profile) bool {
return p.Name == profileName return p.Name == profileName
}) })
if err != nil { if err != nil {

View File

@ -95,7 +95,7 @@ func newProfilesCommand() *cobra.Command {
cmd.RunE = func(cmd *cobra.Command, args []string) error { cmd.RunE = func(cmd *cobra.Command, args []string) error {
var profiles []*profileMetadata var profiles []*profileMetadata
iniFile, err := databrickscfg.Get() iniFile, err := databrickscfg.Get(cmd.Context())
if os.IsNotExist(err) { if os.IsNotExist(err) {
// return empty list for non-configured machines // return empty list for non-configured machines
iniFile = &config.File{ iniFile = &config.File{

View File

@ -5,7 +5,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"os"
"github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle"
"github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/cmdio"
@ -55,7 +54,7 @@ func accountClientOrPrompt(ctx context.Context, cfg *config.Config, allowPrompt
} }
// Try picking a profile dynamically if the current configuration is not valid. // Try picking a profile dynamically if the current configuration is not valid.
profile, err := askForAccountProfile(ctx) profile, err := AskForAccountProfile(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -83,7 +82,7 @@ func MustAccountClient(cmd *cobra.Command, args []string) error {
// 1. only admins will have account configured // 1. only admins will have account configured
// 2. 99% of admins will have access to just one account // 2. 99% of admins will have access to just one account
// hence, we don't need to create a special "DEFAULT_ACCOUNT" profile yet // hence, we don't need to create a special "DEFAULT_ACCOUNT" profile yet
_, profiles, err := databrickscfg.LoadProfiles(databrickscfg.MatchAccountProfiles) _, profiles, err := databrickscfg.LoadProfiles(cmd.Context(), databrickscfg.MatchAccountProfiles)
if err != nil { if err != nil {
return err return err
} }
@ -123,7 +122,7 @@ func workspaceClientOrPrompt(ctx context.Context, cfg *config.Config, allowPromp
} }
// Try picking a profile dynamically if the current configuration is not valid. // Try picking a profile dynamically if the current configuration is not valid.
profile, err := askForWorkspaceProfile(ctx) profile, err := AskForWorkspaceProfile(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -173,21 +172,14 @@ func SetWorkspaceClient(ctx context.Context, w *databricks.WorkspaceClient) cont
return context.WithValue(ctx, &workspaceClient, w) return context.WithValue(ctx, &workspaceClient, w)
} }
func transformLoadError(path string, err error) error { func AskForWorkspaceProfile(ctx context.Context) (string, error) {
if os.IsNotExist(err) { path, err := databrickscfg.GetPath(ctx)
return fmt.Errorf("no configuration file found at %s; please create one first", path)
}
return err
}
func askForWorkspaceProfile(ctx context.Context) (string, error) {
path, err := databrickscfg.GetPath()
if err != nil { if err != nil {
return "", fmt.Errorf("cannot determine Databricks config file path: %w", err) return "", fmt.Errorf("cannot determine Databricks config file path: %w", err)
} }
file, profiles, err := databrickscfg.LoadProfiles(databrickscfg.MatchWorkspaceProfiles) file, profiles, err := databrickscfg.LoadProfiles(ctx, databrickscfg.MatchWorkspaceProfiles)
if err != nil { if err != nil {
return "", transformLoadError(path, err) return "", err
} }
switch len(profiles) { switch len(profiles) {
case 0: case 0:
@ -213,14 +205,14 @@ func askForWorkspaceProfile(ctx context.Context) (string, error) {
return profiles[i].Name, nil return profiles[i].Name, nil
} }
func askForAccountProfile(ctx context.Context) (string, error) { func AskForAccountProfile(ctx context.Context) (string, error) {
path, err := databrickscfg.GetPath() path, err := databrickscfg.GetPath(ctx)
if err != nil { if err != nil {
return "", fmt.Errorf("cannot determine Databricks config file path: %w", err) return "", fmt.Errorf("cannot determine Databricks config file path: %w", err)
} }
file, profiles, err := databrickscfg.LoadProfiles(databrickscfg.MatchAccountProfiles) file, profiles, err := databrickscfg.LoadProfiles(ctx, databrickscfg.MatchAccountProfiles)
if err != nil { if err != nil {
return "", transformLoadError(path, err) return "", err
} }
switch len(profiles) { switch len(profiles) {
case 0: case 0:

View File

@ -1,11 +1,14 @@
package databrickscfg package databrickscfg
import ( import (
"context"
"errors"
"fmt" "fmt"
"os" "io/fs"
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/databricks/cli/libs/env"
"github.com/databricks/databricks-sdk-go/config" "github.com/databricks/databricks-sdk-go/config"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -67,43 +70,45 @@ func MatchAllProfiles(p Profile) bool {
} }
// Get the path to the .databrickscfg file, falling back to the default in the current user's home directory. // Get the path to the .databrickscfg file, falling back to the default in the current user's home directory.
func GetPath() (string, error) { func GetPath(ctx context.Context) (string, error) {
configFile := os.Getenv("DATABRICKS_CONFIG_FILE") configFile := env.Get(ctx, "DATABRICKS_CONFIG_FILE")
if configFile == "" { if configFile == "" {
configFile = "~/.databrickscfg" configFile = "~/.databrickscfg"
} }
if strings.HasPrefix(configFile, "~") { if strings.HasPrefix(configFile, "~") {
homedir, err := os.UserHomeDir() homedir := env.UserHomeDir(ctx)
if err != nil {
return "", fmt.Errorf("cannot find homedir: %w", err)
}
configFile = filepath.Join(homedir, configFile[1:]) configFile = filepath.Join(homedir, configFile[1:])
} }
return configFile, nil return configFile, nil
} }
func Get() (*config.File, error) { var ErrNoConfiguration = errors.New("no configuration file found")
configFile, err := GetPath()
func Get(ctx context.Context) (*config.File, error) {
path, err := GetPath(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot determine Databricks config file path: %w", err) return nil, fmt.Errorf("cannot determine Databricks config file path: %w", err)
} }
return config.LoadFile(configFile) configFile, err := config.LoadFile(path)
if errors.Is(err, fs.ErrNotExist) {
// downstreams depend on ErrNoConfiguration. TODO: expose this error through SDK
return nil, fmt.Errorf("%w at %s; please create one first", ErrNoConfiguration, path)
} else if err != nil {
return nil, err
}
return configFile, nil
} }
func LoadProfiles(fn ProfileMatchFunction) (file string, profiles Profiles, err error) { func LoadProfiles(ctx context.Context, fn ProfileMatchFunction) (file string, profiles Profiles, err error) {
f, err := Get() f, err := Get(ctx)
if err != nil { if err != nil {
return "", nil, fmt.Errorf("cannot load Databricks config file: %w", err) return "", nil, fmt.Errorf("cannot load Databricks config file: %w", err)
} }
homedir, err := os.UserHomeDir()
if err != nil {
return
}
// Replace homedir with ~ if applicable. // Replace homedir with ~ if applicable.
// This is to make the output more readable. // This is to make the output more readable.
file = f.Path() file = filepath.Clean(f.Path())
homedir := filepath.Clean(env.UserHomeDir(ctx))
if strings.HasPrefix(file, homedir) { if strings.HasPrefix(file, homedir) {
file = "~" + file[len(homedir):] file = "~" + file[len(homedir):]
} }
@ -130,7 +135,7 @@ func LoadProfiles(fn ProfileMatchFunction) (file string, profiles Profiles, err
} }
func ProfileCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { func ProfileCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
_, profiles, err := LoadProfiles(MatchAllProfiles) _, profiles, err := LoadProfiles(cmd.Context(), MatchAllProfiles)
if err != nil { if err != nil {
return nil, cobra.ShellCompDirectiveError return nil, cobra.ShellCompDirectiveError
} }

View File

@ -1,9 +1,11 @@
package databrickscfg package databrickscfg
import ( import (
"runtime" "context"
"path/filepath"
"testing" "testing"
"github.com/databricks/cli/libs/env"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -27,27 +29,50 @@ func TestProfilesSearchCaseInsensitive(t *testing.T) {
} }
func TestLoadProfilesReturnsHomedirAsTilde(t *testing.T) { func TestLoadProfilesReturnsHomedirAsTilde(t *testing.T) {
if runtime.GOOS == "windows" { ctx := context.Background()
t.Setenv("USERPROFILE", "./testdata") ctx = env.WithUserHomeDir(ctx, "testdata")
} else { ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "./testdata/databrickscfg")
t.Setenv("HOME", "./testdata") file, _, err := LoadProfiles(ctx, func(p Profile) bool { return true })
}
t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/databrickscfg")
file, _, err := LoadProfiles(func(p Profile) bool { return true })
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "~/databrickscfg", file) require.Equal(t, filepath.Clean("~/databrickscfg"), file)
}
func TestLoadProfilesReturnsHomedirAsTildeExoticFile(t *testing.T) {
ctx := context.Background()
ctx = env.WithUserHomeDir(ctx, "testdata")
ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "~/databrickscfg")
file, _, err := LoadProfiles(ctx, func(p Profile) bool { return true })
require.NoError(t, err)
require.Equal(t, filepath.Clean("~/databrickscfg"), file)
}
func TestLoadProfilesReturnsHomedirAsTildeDefaultFile(t *testing.T) {
ctx := context.Background()
ctx = env.WithUserHomeDir(ctx, "testdata/sample-home")
file, _, err := LoadProfiles(ctx, func(p Profile) bool { return true })
require.NoError(t, err)
require.Equal(t, filepath.Clean("~/.databrickscfg"), file)
}
func TestLoadProfilesNoConfiguration(t *testing.T) {
ctx := context.Background()
ctx = env.WithUserHomeDir(ctx, "testdata")
_, _, err := LoadProfiles(ctx, func(p Profile) bool { return true })
require.ErrorIs(t, err, ErrNoConfiguration)
} }
func TestLoadProfilesMatchWorkspace(t *testing.T) { func TestLoadProfilesMatchWorkspace(t *testing.T) {
t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/databrickscfg") ctx := context.Background()
_, profiles, err := LoadProfiles(MatchWorkspaceProfiles) ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "./testdata/databrickscfg")
_, profiles, err := LoadProfiles(ctx, MatchWorkspaceProfiles)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []string{"DEFAULT", "query", "foo1", "foo2"}, profiles.Names()) assert.Equal(t, []string{"DEFAULT", "query", "foo1", "foo2"}, profiles.Names())
} }
func TestLoadProfilesMatchAccount(t *testing.T) { func TestLoadProfilesMatchAccount(t *testing.T) {
t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/databrickscfg") ctx := context.Background()
_, profiles, err := LoadProfiles(MatchAccountProfiles) ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "./testdata/databrickscfg")
_, profiles, err := LoadProfiles(ctx, MatchAccountProfiles)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []string{"acc"}, profiles.Names()) assert.Equal(t, []string{"acc"}, profiles.Names())
} }

View File

@ -0,0 +1,7 @@
[DEFAULT]
host = https://default
token = default
[acc]
host = https://accounts.cloud.databricks.com
account_id = abc

21
libs/env/context.go vendored
View File

@ -2,7 +2,9 @@ package env
import ( import (
"context" "context"
"fmt"
"os" "os"
"runtime"
"strings" "strings"
) )
@ -63,6 +65,25 @@ func Set(ctx context.Context, key, value string) context.Context {
return setMap(ctx, m) return setMap(ctx, m)
} }
func homeEnvVar() string {
if runtime.GOOS == "windows" {
return "USERPROFILE"
}
return "HOME"
}
func WithUserHomeDir(ctx context.Context, value string) context.Context {
return Set(ctx, homeEnvVar(), value)
}
func UserHomeDir(ctx context.Context) string {
home := Get(ctx, homeEnvVar())
if home == "" {
panic(fmt.Errorf("$HOME is not set"))
}
return home
}
// All returns environment variables that are defined in both os.Environ // All returns environment variables that are defined in both os.Environ
// and this package. `env.Set(ctx, x, y)` will override x from os.Environ. // and this package. `env.Set(ctx, x, y)` will override x from os.Environ.
func All(ctx context.Context) map[string]string { func All(ctx context.Context) map[string]string {

View File

@ -47,3 +47,10 @@ func TestContext(t *testing.T) {
assert.Equal(t, "x=y", all["BAR"]) assert.Equal(t, "x=y", all["BAR"])
assert.NotEmpty(t, all["PATH"]) assert.NotEmpty(t, all["PATH"])
} }
func TestHome(t *testing.T) {
ctx := context.Background()
ctx = WithUserHomeDir(ctx, "...")
home := UserHomeDir(ctx)
assert.Equal(t, "...", home)
}

50
libs/env/loader.go vendored Normal file
View File

@ -0,0 +1,50 @@
package env
import (
"context"
"github.com/databricks/databricks-sdk-go/config"
)
// NewConfigLoader creates Databricks SDK Config loader that is aware of env.Set variables:
//
// ctx = env.Set(ctx, "DATABRICKS_WAREHOUSE_ID", "...")
//
// Usage:
//
// &config.Config{
// Loaders: []config.Loader{
// env.NewConfigLoader(ctx),
// config.ConfigAttributes,
// config.ConfigFile,
// },
// }
func NewConfigLoader(ctx context.Context) *configLoader {
return &configLoader{
ctx: ctx,
}
}
type configLoader struct {
ctx context.Context
}
func (le *configLoader) Name() string {
return "cli-env"
}
func (le *configLoader) Configure(cfg *config.Config) error {
for _, a := range config.ConfigAttributes {
if !a.IsZero(cfg) {
continue
}
for _, k := range a.EnvVars {
v := Get(le.ctx, k)
if v == "" {
continue
}
a.Set(cfg, v)
}
}
return nil
}

26
libs/env/loader_test.go vendored Normal file
View File

@ -0,0 +1,26 @@
package env
import (
"context"
"testing"
"github.com/databricks/databricks-sdk-go/config"
"github.com/stretchr/testify/assert"
)
func TestLoader(t *testing.T) {
ctx := context.Background()
ctx = Set(ctx, "DATABRICKS_WAREHOUSE_ID", "...")
ctx = Set(ctx, "DATABRICKS_CONFIG_PROFILE", "...")
loader := NewConfigLoader(ctx)
cfg := &config.Config{
Profile: "abc",
}
err := loader.Configure(cfg)
assert.NoError(t, err)
assert.Equal(t, "...", cfg.WarehouseID)
assert.Equal(t, "abc", cfg.Profile)
assert.Equal(t, "cli-env", loader.Name())
}