diff --git a/cmd/root/auth.go b/cmd/root/auth.go index 39f7bf22..61068ab3 100644 --- a/cmd/root/auth.go +++ b/cmd/root/auth.go @@ -5,17 +5,15 @@ import ( "errors" "fmt" "os" - "path/filepath" - "strings" "github.com/databricks/cli/bundle" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/databrickscfg" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/config" "github.com/databricks/databricks-sdk-go/service/iam" "github.com/manifoldco/promptui" "github.com/spf13/cobra" - "gopkg.in/ini.v1" ) // Placeholders to use as unique keys in context.Context. @@ -41,19 +39,15 @@ func MustAccountClient(cmd *cobra.Command, args []string) error { // 1. only admins will have account configured // 2. 99% of admins will have access to just one account // hence, we don't need to create a special "DEFAULT_ACCOUNT" profile yet - profiles, err := loadProfiles() + _, profiles, err := databrickscfg.LoadProfiles( + databrickscfg.DefaultPath, + databrickscfg.MatchAccountProfiles, + ) if err != nil { return err } - var items []profile - for _, v := range profiles { - if v.AccountID == "" { - continue - } - items = append(items, v) - } - if len(items) == 1 { - cfg.Profile = items[0].Name + if len(profiles) == 1 { + cfg.Profile = profiles[0].Name } } @@ -121,107 +115,75 @@ TRY_AUTH: // or try picking a config profile dynamically return nil } -type profile struct { - Name string - Host string - AccountID string -} - -func (p profile) Cloud() string { - if strings.Contains(p.Host, ".azuredatabricks.net") { - return "Azure" +func transformLoadError(path string, err error) error { + if os.IsNotExist(err) { + return fmt.Errorf("no configuration file found at %s; please create one first", path) } - if strings.Contains(p.Host, "gcp.databricks.com") { - return "GCP" - } - return "AWS" -} - -func loadProfiles() (profiles []profile, err error) { - homedir, err := os.UserHomeDir() - if err != nil { - return nil, fmt.Errorf("cannot find homedir: %w", err) - } - file := filepath.Join(homedir, ".databrickscfg") - iniFile, err := ini.Load(file) - if err != nil { - return - } - for _, v := range iniFile.Sections() { - all := v.KeysHash() - host, ok := all["host"] - if !ok { - // invalid profile - continue - } - profiles = append(profiles, profile{ - Name: v.Name(), - Host: host, - AccountID: all["account_id"], - }) - } - return profiles, nil + return err } func askForWorkspaceProfile() (string, error) { - profiles, err := loadProfiles() + path := databrickscfg.DefaultPath + file, profiles, err := databrickscfg.LoadProfiles(path, databrickscfg.MatchWorkspaceProfiles) if err != nil { - return "", err + return "", transformLoadError(path, err) } - var items []profile - for _, v := range profiles { - if v.AccountID != "" { - continue - } - items = append(items, v) + switch len(profiles) { + case 0: + return "", fmt.Errorf("%s does not contain workspace profiles; please create one first", path) + case 1: + return profiles[0].Name, nil } - label := "~/.databrickscfg profile" i, _, err := (&promptui.Select{ - Label: label, - Items: items, + Label: fmt.Sprintf("Workspace profiles defined in %s", file), + Items: profiles, + Searcher: profiles.SearchCaseInsensitive, + StartInSearchMode: true, Templates: &promptui.SelectTemplates{ + Label: "{{ . | faint }}", Active: `{{.Name | bold}} ({{.Host|faint}})`, Inactive: `{{.Name}}`, - Selected: fmt.Sprintf(`{{ "%s" | faint }}: {{ .Name | bold }}`, label), + Selected: `{{ "Using workspace profile" | faint }}: {{ .Name | bold }}`, }, - Stdin: os.Stdin, + Stdin: os.Stdin, + Stdout: os.Stderr, }).Run() if err != nil { return "", err } - return items[i].Name, nil + return profiles[i].Name, nil } func askForAccountProfile() (string, error) { - profiles, err := loadProfiles() + path := databrickscfg.DefaultPath + file, profiles, err := databrickscfg.LoadProfiles(path, databrickscfg.MatchAccountProfiles) if err != nil { - return "", err + return "", transformLoadError(path, err) } - var items []profile - for _, v := range profiles { - if v.AccountID == "" { - continue - } - items = append(items, v) + switch len(profiles) { + case 0: + return "", fmt.Errorf("%s does not contain account profiles; please create one first", path) + case 1: + return profiles[0].Name, nil } - if len(items) == 1 { - return items[0].Name, nil - } - label := "~/.databrickscfg profile" i, _, err := (&promptui.Select{ - Label: label, - Items: items, + Label: fmt.Sprintf("Account profiles defined in %s", file), + Items: profiles, + Searcher: profiles.SearchCaseInsensitive, + StartInSearchMode: true, Templates: &promptui.SelectTemplates{ + Label: "{{ . | faint }}", Active: `{{.Name | bold}} ({{.AccountID|faint}} {{.Cloud|faint}})`, Inactive: `{{.Name}}`, - Selected: fmt.Sprintf(`{{ "%s" | faint }}: {{ .Name | bold }}`, label), + Selected: `{{ "Using account profile" | faint }}: {{ .Name | bold }}`, }, - Stdin: os.Stdin, + Stdin: os.Stdin, + Stdout: os.Stderr, }).Run() if err != nil { return "", err } - return items[i].Name, nil + return profiles[i].Name, nil } func WorkspaceClient(ctx context.Context) *databricks.WorkspaceClient { diff --git a/libs/cmdio/io.go b/libs/cmdio/io.go index 762a9455..327f3013 100644 --- a/libs/cmdio/io.go +++ b/libs/cmdio/io.go @@ -10,7 +10,6 @@ import ( "github.com/briandowns/spinner" "github.com/databricks/cli/libs/flags" - "github.com/fatih/color" "github.com/manifoldco/promptui" "github.com/mattn/go-isatty" "golang.org/x/exp/slices" @@ -31,8 +30,13 @@ type cmdIO struct { } func NewIO(outputFormat flags.Output, in io.Reader, out io.Writer, err io.Writer, template string) *cmdIO { + // The check below is similar to color.NoColor but uses the specified err writer. + dumb := os.Getenv("NO_COLOR") != "" || os.Getenv("TERM") == "dumb" + if f, ok := err.(*os.File); ok && !dumb { + dumb = !isatty.IsTerminal(f.Fd()) && !isatty.IsCygwinTerminal(f.Fd()) + } return &cmdIO{ - interactive: !color.NoColor, + interactive: !dumb, outputFormat: outputFormat, template: template, in: in, diff --git a/libs/databrickscfg/profiles.go b/libs/databrickscfg/profiles.go new file mode 100644 index 00000000..60b2a89a --- /dev/null +++ b/libs/databrickscfg/profiles.go @@ -0,0 +1,101 @@ +package databrickscfg + +import ( + "os" + "strings" + + "github.com/databricks/databricks-sdk-go/config" +) + +// Profile holds a subset of the keys in a databrickscfg profile. +// It should only be used for prompting and filtering. +// Use its name to construct a config.Config. +type Profile struct { + Name string + Host string + AccountID string +} + +func (p Profile) Cloud() string { + cfg := config.Config{Host: p.Host} + switch { + case cfg.IsAws(): + return "AWS" + case cfg.IsAzure(): + return "Azure" + case cfg.IsGcp(): + return "GCP" + default: + return "" + } +} + +type Profiles []Profile + +func (p Profiles) Names() []string { + names := make([]string, len(p)) + for i, v := range p { + names[i] = v.Name + } + return names +} + +// SearchCaseInsensitive implements the promptui.Searcher interface. +// This allows the user to immediately starting typing to narrow down the list. +func (p Profiles) SearchCaseInsensitive(input string, index int) bool { + input = strings.ToLower(input) + name := strings.ToLower(p[index].Name) + host := strings.ToLower(p[index].Host) + return strings.Contains(name, input) || strings.Contains(host, input) +} + +type ProfileMatchFunction func(Profile) bool + +func MatchWorkspaceProfiles(p Profile) bool { + return p.AccountID == "" +} + +func MatchAccountProfiles(p Profile) bool { + return p.Host != "" && p.AccountID != "" +} + +const DefaultPath = "~/.databrickscfg" + +func LoadProfiles(path string, fn ProfileMatchFunction) (file string, profiles Profiles, err error) { + f, err := config.LoadFile(path) + if err != nil { + return + } + + homedir, err := os.UserHomeDir() + if err != nil { + return + } + + // Replace homedir with ~ if applicable. + // This is to make the output more readable. + file = f.Path() + if strings.HasPrefix(file, homedir) { + file = "~" + file[len(homedir):] + } + + // Iterate over sections and collect matching profiles. + for _, v := range f.Sections() { + all := v.KeysHash() + host, ok := all["host"] + if !ok { + // invalid profile + continue + } + profile := Profile{ + Name: v.Name(), + Host: host, + AccountID: all["account_id"], + } + if fn(profile) { + profiles = append(profiles, profile) + } + } + + return +} diff --git a/libs/databrickscfg/profiles_test.go b/libs/databrickscfg/profiles_test.go new file mode 100644 index 00000000..582c6658 --- /dev/null +++ b/libs/databrickscfg/profiles_test.go @@ -0,0 +1,50 @@ +package databrickscfg + +import ( + "runtime" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestProfileCloud(t *testing.T) { + assert.Equal(t, Profile{Host: "https://dbc-XXXXXXXX-YYYY.cloud.databricks.com"}.Cloud(), "AWS") + assert.Equal(t, Profile{Host: "https://adb-xxx.y.azuredatabricks.net/"}.Cloud(), "Azure") + assert.Equal(t, Profile{Host: "https://workspace.gcp.databricks.com/"}.Cloud(), "GCP") + assert.Equal(t, Profile{Host: "https://some.invalid.host.com/"}.Cloud(), "AWS") +} + +func TestProfilesSearchCaseInsensitive(t *testing.T) { + profiles := Profiles{ + Profile{Name: "foo", Host: "bar"}, + } + assert.True(t, profiles.SearchCaseInsensitive("f", 0)) + assert.True(t, profiles.SearchCaseInsensitive("OO", 0)) + assert.True(t, profiles.SearchCaseInsensitive("b", 0)) + assert.True(t, profiles.SearchCaseInsensitive("AR", 0)) + assert.False(t, profiles.SearchCaseInsensitive("qu", 0)) +} + +func TestLoadProfilesReturnsHomedirAsTilde(t *testing.T) { + if runtime.GOOS == "windows" { + t.Setenv("USERPROFILE", "./testdata") + } else { + t.Setenv("HOME", "./testdata") + } + file, _, err := LoadProfiles("./testdata/databrickscfg", func(p Profile) bool { return true }) + require.NoError(t, err) + assert.Equal(t, "~/databrickscfg", file) +} + +func TestLoadProfilesMatchWorkspace(t *testing.T) { + _, profiles, err := LoadProfiles("./testdata/databrickscfg", MatchWorkspaceProfiles) + require.NoError(t, err) + assert.Equal(t, []string{"DEFAULT", "query", "foo1", "foo2"}, profiles.Names()) +} + +func TestLoadProfilesMatchAccount(t *testing.T) { + _, profiles, err := LoadProfiles("./testdata/databrickscfg", MatchAccountProfiles) + require.NoError(t, err) + assert.Equal(t, []string{"acc"}, profiles.Names()) +}