From 1225fc0c13edde6cd267d7523dd4bfa00621fa82 Mon Sep 17 00:00:00 2001 From: shreyas-goenka <88374338+shreyas-goenka@users.noreply.github.com> Date: Wed, 14 Aug 2024 18:31:00 +0530 Subject: [PATCH] Fix host resolution order in `auth login` (#1370) ## Changes The `auth login` command today prefers a host URL specified in a profile before selecting the one explicitly provided by a user as a command line argument. This PR fixes this bug and refactors the code to make it more linear and easy to read. Note that the same issue exists in the `auth token` command and is fixed here as well. ## Tests Unit tests, and manual testing. --- cmd/auth/auth.go | 27 ++++++------ cmd/auth/login.go | 70 +++++++++++++++++++++----------- cmd/auth/login_test.go | 68 +++++++++++++++++++++++++++++++ cmd/auth/testdata/.databrickscfg | 9 ++++ libs/auth/oauth.go | 1 - 5 files changed, 137 insertions(+), 38 deletions(-) create mode 100644 cmd/auth/testdata/.databrickscfg diff --git a/cmd/auth/auth.go b/cmd/auth/auth.go index 79e1063b..ceceae25 100644 --- a/cmd/auth/auth.go +++ b/cmd/auth/auth.go @@ -2,6 +2,7 @@ package auth import ( "context" + "fmt" "github.com/databricks/cli/libs/auth" "github.com/databricks/cli/libs/cmdio" @@ -34,25 +35,23 @@ GCP: https://docs.gcp.databricks.com/dev-tools/auth/index.html`, } func promptForHost(ctx context.Context) (string, error) { - prompt := cmdio.Prompt(ctx) - prompt.Label = "Databricks Host (e.g. https://.cloud.databricks.com)" - // Validate? - host, err := prompt.Run() - if err != nil { - return "", err + if !cmdio.IsInTTY(ctx) { + return "", fmt.Errorf("the command is being run in a non-interactive environment, please specify a host using --host") } - return host, nil + + prompt := cmdio.Prompt(ctx) + prompt.Label = "Databricks host (e.g. https://.cloud.databricks.com)" + return prompt.Run() } func promptForAccountID(ctx context.Context) (string, error) { + if !cmdio.IsInTTY(ctx) { + return "", fmt.Errorf("the command is being run in a non-interactive environment, please specify an account ID using --account-id") + } + prompt := cmdio.Prompt(ctx) - prompt.Label = "Databricks Account ID" + prompt.Label = "Databricks account ID" prompt.Default = "" prompt.AllowEdit = true - // Validate? - accountId, err := prompt.Run() - if err != nil { - return "", err - } - return accountId, nil + return prompt.Run() } diff --git a/cmd/auth/login.go b/cmd/auth/login.go index 11cba8e5..f87a2a02 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -17,18 +17,16 @@ import ( "github.com/spf13/cobra" ) -func configureHost(ctx context.Context, persistentAuth *auth.PersistentAuth, args []string, argIndex int) error { - if len(args) > argIndex { - persistentAuth.Host = args[argIndex] - return nil +func promptForProfile(ctx context.Context, defaultValue string) (string, error) { + if !cmdio.IsInTTY(ctx) { + return "", fmt.Errorf("the command is being run in a non-interactive environment, please specify a profile using --profile") } - host, err := promptForHost(ctx) - if err != nil { - return err - } - persistentAuth.Host = host - return nil + prompt := cmdio.Prompt(ctx) + prompt.Label = "Databricks profile name" + prompt.Default = defaultValue + prompt.AllowEdit = true + return prompt.Run() } const minimalDbConnectVersion = "13.1" @@ -93,23 +91,18 @@ depends on the existing profiles you have set in your configuration file cmd.RunE = func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() + profileName := cmd.Flag("profile").Value.String() - var profileName string - profileFlag := cmd.Flag("profile") - if profileFlag != nil && profileFlag.Value.String() != "" { - profileName = profileFlag.Value.String() - } else if cmdio.IsInTTY(ctx) { - prompt := cmdio.Prompt(ctx) - prompt.Label = "Databricks Profile Name" - prompt.Default = persistentAuth.ProfileName() - prompt.AllowEdit = true - profile, err := prompt.Run() + // If the user has not specified a profile name, prompt for one. + if profileName == "" { + var err error + profileName, err = promptForProfile(ctx, persistentAuth.ProfileName()) if err != nil { return err } - profileName = profile } + // Set the host and account-id based on the provided arguments and flags. err := setHostAndAccountId(ctx, profileName, persistentAuth, args) if err != nil { return err @@ -167,7 +160,23 @@ depends on the existing profiles you have set in your configuration file return cmd } +// Sets the host in the persistentAuth object based on the provided arguments and flags. +// Follows the following precedence: +// 1. [HOST] (first positional argument) or --host flag. Error if both are specified. +// 2. Profile host, if available. +// 3. Prompt the user for the host. +// +// Set the account in the persistentAuth object based on the flags. +// Follows the following precedence: +// 1. --account-id flag. +// 2. account-id from the specified profile, if available. +// 3. Prompt the user for the account-id. func setHostAndAccountId(ctx context.Context, profileName string, persistentAuth *auth.PersistentAuth, args []string) error { + // If both [HOST] and --host are provided, return an error. + if len(args) > 0 && persistentAuth.Host != "" { + return fmt.Errorf("please only provide a host as an argument or a flag, not both") + } + profiler := profile.GetProfiler(ctx) // If the chosen profile has a hostname and the user hasn't specified a host, infer the host from the profile. profiles, err := profiler.LoadProfiles(ctx, profile.WithName(profileName)) @@ -177,17 +186,32 @@ func setHostAndAccountId(ctx context.Context, profileName string, persistentAuth } if persistentAuth.Host == "" { - if len(profiles) > 0 && profiles[0].Host != "" { + if len(args) > 0 { + // If [HOST] is provided, set the host to the provided positional argument. + persistentAuth.Host = args[0] + } else if len(profiles) > 0 && profiles[0].Host != "" { + // If neither [HOST] nor --host are provided, and the profile has a host, use it. persistentAuth.Host = profiles[0].Host } else { - configureHost(ctx, persistentAuth, args, 0) + // If neither [HOST] nor --host are provided, and the profile does not have a host, + // then prompt the user for a host. + hostName, err := promptForHost(ctx) + if err != nil { + return err + } + persistentAuth.Host = hostName } } + + // If the account-id was not provided as a cmd line flag, try to read it from + // the specified profile. isAccountClient := (&config.Config{Host: persistentAuth.Host}).IsAccountClient() if isAccountClient && persistentAuth.AccountID == "" { if len(profiles) > 0 && profiles[0].AccountID != "" { persistentAuth.AccountID = profiles[0].AccountID } else { + // Prompt user for the account-id if it we could not get it from a + // profile. accountId, err := promptForAccountID(ctx) if err != nil { return err diff --git a/cmd/auth/login_test.go b/cmd/auth/login_test.go index ce3ca5ae..d0fa5a16 100644 --- a/cmd/auth/login_test.go +++ b/cmd/auth/login_test.go @@ -5,8 +5,10 @@ import ( "testing" "github.com/databricks/cli/libs/auth" + "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/env" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestSetHostDoesNotFailWithNoDatabrickscfg(t *testing.T) { @@ -15,3 +17,69 @@ func TestSetHostDoesNotFailWithNoDatabrickscfg(t *testing.T) { err := setHostAndAccountId(ctx, "foo", &auth.PersistentAuth{Host: "test"}, []string{}) assert.NoError(t, err) } + +func TestSetHost(t *testing.T) { + var persistentAuth auth.PersistentAuth + t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/.databrickscfg") + ctx, _ := cmdio.SetupTest(context.Background()) + + // Test error when both flag and argument are provided + persistentAuth.Host = "val from --host" + err := setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{"val from [HOST]"}) + assert.EqualError(t, err, "please only provide a host as an argument or a flag, not both") + + // Test setting host from flag + persistentAuth.Host = "val from --host" + err = setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{}) + assert.NoError(t, err) + assert.Equal(t, "val from --host", persistentAuth.Host) + + // Test setting host from argument + persistentAuth.Host = "" + err = setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{"val from [HOST]"}) + assert.NoError(t, err) + assert.Equal(t, "val from [HOST]", persistentAuth.Host) + + // Test setting host from profile + persistentAuth.Host = "" + err = setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{}) + assert.NoError(t, err) + assert.Equal(t, "https://www.host1.com", persistentAuth.Host) + + // Test setting host from profile + persistentAuth.Host = "" + err = setHostAndAccountId(ctx, "profile-2", &persistentAuth, []string{}) + assert.NoError(t, err) + assert.Equal(t, "https://www.host2.com", persistentAuth.Host) + + // Test host is not set. Should prompt. + persistentAuth.Host = "" + err = setHostAndAccountId(ctx, "", &persistentAuth, []string{}) + assert.EqualError(t, err, "the command is being run in a non-interactive environment, please specify a host using --host") +} + +func TestSetAccountId(t *testing.T) { + var persistentAuth auth.PersistentAuth + t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/.databrickscfg") + ctx, _ := cmdio.SetupTest(context.Background()) + + // Test setting account-id from flag + persistentAuth.AccountID = "val from --account-id" + err := setHostAndAccountId(ctx, "account-profile", &persistentAuth, []string{}) + assert.NoError(t, err) + assert.Equal(t, "https://accounts.cloud.databricks.com", persistentAuth.Host) + assert.Equal(t, "val from --account-id", persistentAuth.AccountID) + + // Test setting account_id from profile + persistentAuth.AccountID = "" + err = setHostAndAccountId(ctx, "account-profile", &persistentAuth, []string{}) + require.NoError(t, err) + assert.Equal(t, "https://accounts.cloud.databricks.com", persistentAuth.Host) + assert.Equal(t, "id-from-profile", persistentAuth.AccountID) + + // Neither flag nor profile account-id is set, should prompt + persistentAuth.AccountID = "" + persistentAuth.Host = "https://accounts.cloud.databricks.com" + err = setHostAndAccountId(ctx, "", &persistentAuth, []string{}) + assert.EqualError(t, err, "the command is being run in a non-interactive environment, please specify an account ID using --account-id") +} diff --git a/cmd/auth/testdata/.databrickscfg b/cmd/auth/testdata/.databrickscfg new file mode 100644 index 00000000..06e55224 --- /dev/null +++ b/cmd/auth/testdata/.databrickscfg @@ -0,0 +1,9 @@ +[profile-1] +host = https://www.host1.com + +[profile-2] +host = https://www.host2.com + +[account-profile] +host = https://accounts.cloud.databricks.com +account_id = id-from-profile diff --git a/libs/auth/oauth.go b/libs/auth/oauth.go index 1f3e032d..7c1cb957 100644 --- a/libs/auth/oauth.go +++ b/libs/auth/oauth.go @@ -105,7 +105,6 @@ func (a *PersistentAuth) Load(ctx context.Context) (*oauth2.Token, error) { } func (a *PersistentAuth) ProfileName() string { - // TODO: get profile name from interactive input if a.AccountID != "" { return fmt.Sprintf("ACCOUNT-%s", a.AccountID) }