Fix host resolution order in `auth login` (#1370)

## Changes
The `auth login` command today prefers a host URL specified in a profile
before selecting the one explicitly provided by a user as a command line
argument.

This PR fixes this bug and refactors the code to make it more linear and
easy to read. Note that the same issue exists in the `auth token`
command and is fixed here as well.

## Tests
Unit tests, and manual testing.
This commit is contained in:
shreyas-goenka 2024-08-14 18:31:00 +05:30 committed by GitHub
parent 48ff18e5fc
commit 1225fc0c13
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 137 additions and 38 deletions

View File

@ -2,6 +2,7 @@ package auth
import ( import (
"context" "context"
"fmt"
"github.com/databricks/cli/libs/auth" "github.com/databricks/cli/libs/auth"
"github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/cmdio"
@ -34,25 +35,23 @@ GCP: https://docs.gcp.databricks.com/dev-tools/auth/index.html`,
} }
func promptForHost(ctx context.Context) (string, error) { func promptForHost(ctx context.Context) (string, error) {
prompt := cmdio.Prompt(ctx) if !cmdio.IsInTTY(ctx) {
prompt.Label = "Databricks Host (e.g. https://<databricks-instance>.cloud.databricks.com)" return "", fmt.Errorf("the command is being run in a non-interactive environment, please specify a host using --host")
// Validate?
host, err := prompt.Run()
if err != nil {
return "", err
} }
return host, nil
prompt := cmdio.Prompt(ctx)
prompt.Label = "Databricks host (e.g. https://<databricks-instance>.cloud.databricks.com)"
return prompt.Run()
} }
func promptForAccountID(ctx context.Context) (string, error) { func promptForAccountID(ctx context.Context) (string, error) {
if !cmdio.IsInTTY(ctx) {
return "", fmt.Errorf("the command is being run in a non-interactive environment, please specify an account ID using --account-id")
}
prompt := cmdio.Prompt(ctx) prompt := cmdio.Prompt(ctx)
prompt.Label = "Databricks Account ID" prompt.Label = "Databricks account ID"
prompt.Default = "" prompt.Default = ""
prompt.AllowEdit = true prompt.AllowEdit = true
// Validate? return prompt.Run()
accountId, err := prompt.Run()
if err != nil {
return "", err
}
return accountId, nil
} }

View File

@ -17,18 +17,16 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
func configureHost(ctx context.Context, persistentAuth *auth.PersistentAuth, args []string, argIndex int) error { func promptForProfile(ctx context.Context, defaultValue string) (string, error) {
if len(args) > argIndex { if !cmdio.IsInTTY(ctx) {
persistentAuth.Host = args[argIndex] return "", fmt.Errorf("the command is being run in a non-interactive environment, please specify a profile using --profile")
return nil
} }
host, err := promptForHost(ctx) prompt := cmdio.Prompt(ctx)
if err != nil { prompt.Label = "Databricks profile name"
return err prompt.Default = defaultValue
} prompt.AllowEdit = true
persistentAuth.Host = host return prompt.Run()
return nil
} }
const minimalDbConnectVersion = "13.1" const minimalDbConnectVersion = "13.1"
@ -93,23 +91,18 @@ depends on the existing profiles you have set in your configuration file
cmd.RunE = func(cmd *cobra.Command, args []string) error { cmd.RunE = func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context() ctx := cmd.Context()
profileName := cmd.Flag("profile").Value.String()
var profileName string // If the user has not specified a profile name, prompt for one.
profileFlag := cmd.Flag("profile") if profileName == "" {
if profileFlag != nil && profileFlag.Value.String() != "" { var err error
profileName = profileFlag.Value.String() profileName, err = promptForProfile(ctx, persistentAuth.ProfileName())
} else if cmdio.IsInTTY(ctx) {
prompt := cmdio.Prompt(ctx)
prompt.Label = "Databricks Profile Name"
prompt.Default = persistentAuth.ProfileName()
prompt.AllowEdit = true
profile, err := prompt.Run()
if err != nil { if err != nil {
return err return err
} }
profileName = profile
} }
// Set the host and account-id based on the provided arguments and flags.
err := setHostAndAccountId(ctx, profileName, persistentAuth, args) err := setHostAndAccountId(ctx, profileName, persistentAuth, args)
if err != nil { if err != nil {
return err return err
@ -167,7 +160,23 @@ depends on the existing profiles you have set in your configuration file
return cmd return cmd
} }
// Sets the host in the persistentAuth object based on the provided arguments and flags.
// Follows the following precedence:
// 1. [HOST] (first positional argument) or --host flag. Error if both are specified.
// 2. Profile host, if available.
// 3. Prompt the user for the host.
//
// Set the account in the persistentAuth object based on the flags.
// Follows the following precedence:
// 1. --account-id flag.
// 2. account-id from the specified profile, if available.
// 3. Prompt the user for the account-id.
func setHostAndAccountId(ctx context.Context, profileName string, persistentAuth *auth.PersistentAuth, args []string) error { func setHostAndAccountId(ctx context.Context, profileName string, persistentAuth *auth.PersistentAuth, args []string) error {
// If both [HOST] and --host are provided, return an error.
if len(args) > 0 && persistentAuth.Host != "" {
return fmt.Errorf("please only provide a host as an argument or a flag, not both")
}
profiler := profile.GetProfiler(ctx) profiler := profile.GetProfiler(ctx)
// If the chosen profile has a hostname and the user hasn't specified a host, infer the host from the profile. // If the chosen profile has a hostname and the user hasn't specified a host, infer the host from the profile.
profiles, err := profiler.LoadProfiles(ctx, profile.WithName(profileName)) profiles, err := profiler.LoadProfiles(ctx, profile.WithName(profileName))
@ -177,17 +186,32 @@ func setHostAndAccountId(ctx context.Context, profileName string, persistentAuth
} }
if persistentAuth.Host == "" { if persistentAuth.Host == "" {
if len(profiles) > 0 && profiles[0].Host != "" { if len(args) > 0 {
// If [HOST] is provided, set the host to the provided positional argument.
persistentAuth.Host = args[0]
} else if len(profiles) > 0 && profiles[0].Host != "" {
// If neither [HOST] nor --host are provided, and the profile has a host, use it.
persistentAuth.Host = profiles[0].Host persistentAuth.Host = profiles[0].Host
} else { } else {
configureHost(ctx, persistentAuth, args, 0) // If neither [HOST] nor --host are provided, and the profile does not have a host,
// then prompt the user for a host.
hostName, err := promptForHost(ctx)
if err != nil {
return err
}
persistentAuth.Host = hostName
} }
} }
// If the account-id was not provided as a cmd line flag, try to read it from
// the specified profile.
isAccountClient := (&config.Config{Host: persistentAuth.Host}).IsAccountClient() isAccountClient := (&config.Config{Host: persistentAuth.Host}).IsAccountClient()
if isAccountClient && persistentAuth.AccountID == "" { if isAccountClient && persistentAuth.AccountID == "" {
if len(profiles) > 0 && profiles[0].AccountID != "" { if len(profiles) > 0 && profiles[0].AccountID != "" {
persistentAuth.AccountID = profiles[0].AccountID persistentAuth.AccountID = profiles[0].AccountID
} else { } else {
// Prompt user for the account-id if it we could not get it from a
// profile.
accountId, err := promptForAccountID(ctx) accountId, err := promptForAccountID(ctx)
if err != nil { if err != nil {
return err return err

View File

@ -5,8 +5,10 @@ import (
"testing" "testing"
"github.com/databricks/cli/libs/auth" "github.com/databricks/cli/libs/auth"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/env" "github.com/databricks/cli/libs/env"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestSetHostDoesNotFailWithNoDatabrickscfg(t *testing.T) { func TestSetHostDoesNotFailWithNoDatabrickscfg(t *testing.T) {
@ -15,3 +17,69 @@ func TestSetHostDoesNotFailWithNoDatabrickscfg(t *testing.T) {
err := setHostAndAccountId(ctx, "foo", &auth.PersistentAuth{Host: "test"}, []string{}) err := setHostAndAccountId(ctx, "foo", &auth.PersistentAuth{Host: "test"}, []string{})
assert.NoError(t, err) assert.NoError(t, err)
} }
func TestSetHost(t *testing.T) {
var persistentAuth auth.PersistentAuth
t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/.databrickscfg")
ctx, _ := cmdio.SetupTest(context.Background())
// Test error when both flag and argument are provided
persistentAuth.Host = "val from --host"
err := setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{"val from [HOST]"})
assert.EqualError(t, err, "please only provide a host as an argument or a flag, not both")
// Test setting host from flag
persistentAuth.Host = "val from --host"
err = setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{})
assert.NoError(t, err)
assert.Equal(t, "val from --host", persistentAuth.Host)
// Test setting host from argument
persistentAuth.Host = ""
err = setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{"val from [HOST]"})
assert.NoError(t, err)
assert.Equal(t, "val from [HOST]", persistentAuth.Host)
// Test setting host from profile
persistentAuth.Host = ""
err = setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{})
assert.NoError(t, err)
assert.Equal(t, "https://www.host1.com", persistentAuth.Host)
// Test setting host from profile
persistentAuth.Host = ""
err = setHostAndAccountId(ctx, "profile-2", &persistentAuth, []string{})
assert.NoError(t, err)
assert.Equal(t, "https://www.host2.com", persistentAuth.Host)
// Test host is not set. Should prompt.
persistentAuth.Host = ""
err = setHostAndAccountId(ctx, "", &persistentAuth, []string{})
assert.EqualError(t, err, "the command is being run in a non-interactive environment, please specify a host using --host")
}
func TestSetAccountId(t *testing.T) {
var persistentAuth auth.PersistentAuth
t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/.databrickscfg")
ctx, _ := cmdio.SetupTest(context.Background())
// Test setting account-id from flag
persistentAuth.AccountID = "val from --account-id"
err := setHostAndAccountId(ctx, "account-profile", &persistentAuth, []string{})
assert.NoError(t, err)
assert.Equal(t, "https://accounts.cloud.databricks.com", persistentAuth.Host)
assert.Equal(t, "val from --account-id", persistentAuth.AccountID)
// Test setting account_id from profile
persistentAuth.AccountID = ""
err = setHostAndAccountId(ctx, "account-profile", &persistentAuth, []string{})
require.NoError(t, err)
assert.Equal(t, "https://accounts.cloud.databricks.com", persistentAuth.Host)
assert.Equal(t, "id-from-profile", persistentAuth.AccountID)
// Neither flag nor profile account-id is set, should prompt
persistentAuth.AccountID = ""
persistentAuth.Host = "https://accounts.cloud.databricks.com"
err = setHostAndAccountId(ctx, "", &persistentAuth, []string{})
assert.EqualError(t, err, "the command is being run in a non-interactive environment, please specify an account ID using --account-id")
}

9
cmd/auth/testdata/.databrickscfg vendored Normal file
View File

@ -0,0 +1,9 @@
[profile-1]
host = https://www.host1.com
[profile-2]
host = https://www.host2.com
[account-profile]
host = https://accounts.cloud.databricks.com
account_id = id-from-profile

View File

@ -105,7 +105,6 @@ func (a *PersistentAuth) Load(ctx context.Context) (*oauth2.Token, error) {
} }
func (a *PersistentAuth) ProfileName() string { func (a *PersistentAuth) ProfileName() string {
// TODO: get profile name from interactive input
if a.AccountID != "" { if a.AccountID != "" {
return fmt.Sprintf("ACCOUNT-%s", a.AccountID) return fmt.Sprintf("ACCOUNT-%s", a.AccountID)
} }