diff --git a/cmd/auth/login.go b/cmd/auth/login.go index 0e1f0125..14a875b1 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -2,9 +2,13 @@ package auth import ( "context" + "fmt" "time" "github.com/databricks/cli/libs/auth" + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/databrickscfg" + "github.com/databricks/databricks-sdk-go/config" "github.com/spf13/cobra" ) @@ -20,7 +24,40 @@ var loginCmd = &cobra.Command{ defer perisistentAuth.Close() ctx, cancel := context.WithTimeout(cmd.Context(), loginTimeout) defer cancel() - return perisistentAuth.Challenge(ctx) + + var profileName string + profileFlag := cmd.Flag("profile") + if profileFlag != nil && profileFlag.Value.String() != "" { + profileName = profileFlag.Value.String() + } else { + prompt := cmdio.Prompt(ctx) + prompt.Label = "Databricks Profile Name" + prompt.Default = perisistentAuth.ProfileName() + prompt.AllowEdit = true + profile, err := prompt.Run() + if err != nil { + return err + } + profileName = profile + } + err := perisistentAuth.Challenge(ctx) + if err != nil { + return err + } + + err = databrickscfg.SaveToProfile(ctx, &config.Config{ + Host: perisistentAuth.Host, + AccountID: perisistentAuth.AccountID, + AuthType: "databricks-cli", + Profile: profileName, + }) + + if err != nil { + return err + } + + cmdio.LogString(ctx, fmt.Sprintf("Profile %s was successfully saved", profileName)) + return nil }, } diff --git a/libs/auth/oauth.go b/libs/auth/oauth.go index 1171103c..b7e0ce2f 100644 --- a/libs/auth/oauth.go +++ b/libs/auth/oauth.go @@ -16,8 +16,6 @@ import ( "time" "github.com/databricks/cli/libs/auth/cache" - "github.com/databricks/cli/libs/databrickscfg" - "github.com/databricks/databricks-sdk-go/config" "github.com/databricks/databricks-sdk-go/retries" "github.com/pkg/browser" "golang.org/x/oauth2" @@ -97,7 +95,7 @@ func (a *PersistentAuth) Load(ctx context.Context) (*oauth2.Token, error) { return refreshed, nil } -func (a *PersistentAuth) profileName() string { +func (a *PersistentAuth) ProfileName() string { // TODO: get profile name from interactive input if a.AccountID != "" { return fmt.Sprintf("ACCOUNT-%s", a.AccountID) @@ -132,12 +130,7 @@ func (a *PersistentAuth) Challenge(ctx context.Context) error { if err != nil { return fmt.Errorf("store: %w", err) } - return databrickscfg.SaveToProfile(ctx, &config.Config{ - Host: a.Host, - AccountID: a.AccountID, - AuthType: "databricks-cli", - Profile: a.profileName(), - }) + return nil } func (a *PersistentAuth) init(ctx context.Context) error { diff --git a/libs/databrickscfg/ops.go b/libs/databrickscfg/ops.go index 52a966ef..4a4a27b0 100644 --- a/libs/databrickscfg/ops.go +++ b/libs/databrickscfg/ops.go @@ -47,8 +47,8 @@ func loadOrCreateConfigFile(filename string) (*config.File, error) { func matchOrCreateSection(ctx context.Context, configFile *config.File, cfg *config.Config) (*ini.Section, error) { section, err := findMatchingProfile(configFile, func(s *ini.Section) bool { - if cfg.Profile == s.Name() { - return true + if cfg.Profile != "" { + return cfg.Profile == s.Name() } raw := s.KeysHash() if cfg.AccountID != "" {