Prompt once for a client profile (#727)

## Changes

The previous implementation ran the risk of infinite looping for the
account client due to a mismatch in determining what constitutes an
account client between the CLI and SDK (see
[here](83443bae8d/libs/databrickscfg/profiles.go (L61))
and
[here](0fdc5165e5/config/config.go (L160))).

Ultimately, this code must never infinite loop. If a user is prompted
and selects a profile that cannot be used, they should receive that
feedback immediately and try again, instead of being prompted again.

Related to #726.

## Tests
<!-- How is this tested? -->
This commit is contained in:
Pieter Noordhuis 2023-09-11 17:32:24 +02:00 committed by GitHub
parent 373f441eb2
commit 0cb05d1ded
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 318 additions and 45 deletions

View File

@ -25,13 +25,57 @@ func initProfileFlag(cmd *cobra.Command) {
cmd.RegisterFlagCompletionFunc("profile", databrickscfg.ProfileCompletion) cmd.RegisterFlagCompletionFunc("profile", databrickscfg.ProfileCompletion)
} }
func profileFlagValue(cmd *cobra.Command) (string, bool) {
profileFlag := cmd.Flag("profile")
if profileFlag == nil {
return "", false
}
value := profileFlag.Value.String()
return value, value != ""
}
// Helper function to create an account client or prompt once if the given configuration is not valid.
func accountClientOrPrompt(ctx context.Context, cfg *config.Config, allowPrompt bool) (*databricks.AccountClient, error) {
a, err := databricks.NewAccountClient((*databricks.Config)(cfg))
if err == nil {
err = a.Config.Authenticate(emptyHttpRequest(ctx))
}
prompt := false
if allowPrompt && err != nil && cmdio.IsInteractive(ctx) {
// Prompt to select a profile if the current configuration is not an account client.
prompt = prompt || errors.Is(err, databricks.ErrNotAccountClient)
// Prompt to select a profile if the current configuration doesn't resolve to a credential provider.
prompt = prompt || errors.Is(err, config.ErrCannotConfigureAuth)
}
if !prompt {
// If we are not prompting, we can return early.
return a, err
}
// Try picking a profile dynamically if the current configuration is not valid.
profile, err := askForAccountProfile(ctx)
if err != nil {
return nil, err
}
a, err = databricks.NewAccountClient(&databricks.Config{Profile: profile})
if err == nil {
err = a.Config.Authenticate(emptyHttpRequest(ctx))
if err != nil {
return nil, err
}
}
return a, nil
}
func MustAccountClient(cmd *cobra.Command, args []string) error { func MustAccountClient(cmd *cobra.Command, args []string) error {
cfg := &config.Config{} cfg := &config.Config{}
// command-line flag can specify the profile in use // The command-line profile flag takes precedence over DATABRICKS_CONFIG_PROFILE.
profileFlag := cmd.Flag("profile") profile, hasProfileFlag := profileFlagValue(cmd)
if profileFlag != nil { if hasProfileFlag {
cfg.Profile = profileFlag.Value.String() cfg.Profile = profile
} }
if cfg.Profile == "" { if cfg.Profile == "" {
@ -48,16 +92,8 @@ func MustAccountClient(cmd *cobra.Command, args []string) error {
} }
} }
TRY_AUTH: // or try picking a config profile dynamically allowPrompt := !hasProfileFlag
a, err := databricks.NewAccountClient((*databricks.Config)(cfg)) a, err := accountClientOrPrompt(cmd.Context(), cfg, allowPrompt)
if cmdio.IsInteractive(cmd.Context()) && errors.Is(err, databricks.ErrNotAccountClient) {
profile, err := askForAccountProfile()
if err != nil {
return err
}
cfg = &config.Config{Profile: profile}
goto TRY_AUTH
}
if err != nil { if err != nil {
return err return err
} }
@ -66,13 +102,48 @@ TRY_AUTH: // or try picking a config profile dynamically
return nil return nil
} }
// Helper function to create a workspace client or prompt once if the given configuration is not valid.
func workspaceClientOrPrompt(ctx context.Context, cfg *config.Config, allowPrompt bool) (*databricks.WorkspaceClient, error) {
w, err := databricks.NewWorkspaceClient((*databricks.Config)(cfg))
if err == nil {
err = w.Config.Authenticate(emptyHttpRequest(ctx))
}
prompt := false
if allowPrompt && err != nil && cmdio.IsInteractive(ctx) {
// Prompt to select a profile if the current configuration is not a workspace client.
prompt = prompt || errors.Is(err, databricks.ErrNotWorkspaceClient)
// Prompt to select a profile if the current configuration doesn't resolve to a credential provider.
prompt = prompt || errors.Is(err, config.ErrCannotConfigureAuth)
}
if !prompt {
// If we are not prompting, we can return early.
return w, err
}
// Try picking a profile dynamically if the current configuration is not valid.
profile, err := askForWorkspaceProfile(ctx)
if err != nil {
return nil, err
}
w, err = databricks.NewWorkspaceClient(&databricks.Config{Profile: profile})
if err == nil {
err = w.Config.Authenticate(emptyHttpRequest(ctx))
if err != nil {
return nil, err
}
}
return w, nil
}
func MustWorkspaceClient(cmd *cobra.Command, args []string) error { func MustWorkspaceClient(cmd *cobra.Command, args []string) error {
cfg := &config.Config{} cfg := &config.Config{}
// command-line flag takes precedence over environment variable // The command-line profile flag takes precedence over DATABRICKS_CONFIG_PROFILE.
profileFlag := cmd.Flag("profile") profile, hasProfileFlag := profileFlagValue(cmd)
if profileFlag != nil { if hasProfileFlag {
cfg.Profile = profileFlag.Value.String() cfg.Profile = profile
} }
// try configuring a bundle // try configuring a bundle
@ -87,24 +158,13 @@ func MustWorkspaceClient(cmd *cobra.Command, args []string) error {
cfg = currentBundle.WorkspaceClient().Config cfg = currentBundle.WorkspaceClient().Config
} }
TRY_AUTH: // or try picking a config profile dynamically allowPrompt := !hasProfileFlag
w, err := workspaceClientOrPrompt(cmd.Context(), cfg, allowPrompt)
if err != nil {
return err
}
ctx := cmd.Context() ctx := cmd.Context()
w, err := databricks.NewWorkspaceClient((*databricks.Config)(cfg))
if err != nil {
return err
}
err = w.Config.Authenticate(emptyHttpRequest(ctx))
if cmdio.IsInteractive(ctx) && errors.Is(err, config.ErrCannotConfigureAuth) {
profile, err := askForWorkspaceProfile()
if err != nil {
return err
}
cfg = &config.Config{Profile: profile}
goto TRY_AUTH
}
if err != nil {
return err
}
ctx = context.WithValue(ctx, &workspaceClient, w) ctx = context.WithValue(ctx, &workspaceClient, w)
cmd.SetContext(ctx) cmd.SetContext(ctx)
return nil return nil
@ -121,7 +181,7 @@ func transformLoadError(path string, err error) error {
return err return err
} }
func askForWorkspaceProfile() (string, error) { func askForWorkspaceProfile(ctx context.Context) (string, error) {
path, err := databrickscfg.GetPath() path, err := databrickscfg.GetPath()
if err != nil { if err != nil {
return "", fmt.Errorf("cannot determine Databricks config file path: %w", err) return "", fmt.Errorf("cannot determine Databricks config file path: %w", err)
@ -136,7 +196,7 @@ func askForWorkspaceProfile() (string, error) {
case 1: case 1:
return profiles[0].Name, nil return profiles[0].Name, nil
} }
i, _, err := (&promptui.Select{ i, _, err := cmdio.RunSelect(ctx, &promptui.Select{
Label: fmt.Sprintf("Workspace profiles defined in %s", file), Label: fmt.Sprintf("Workspace profiles defined in %s", file),
Items: profiles, Items: profiles,
Searcher: profiles.SearchCaseInsensitive, Searcher: profiles.SearchCaseInsensitive,
@ -147,16 +207,14 @@ func askForWorkspaceProfile() (string, error) {
Inactive: `{{.Name}}`, Inactive: `{{.Name}}`,
Selected: `{{ "Using workspace profile" | faint }}: {{ .Name | bold }}`, Selected: `{{ "Using workspace profile" | faint }}: {{ .Name | bold }}`,
}, },
Stdin: os.Stdin, })
Stdout: os.Stderr,
}).Run()
if err != nil { if err != nil {
return "", err return "", err
} }
return profiles[i].Name, nil return profiles[i].Name, nil
} }
func askForAccountProfile() (string, error) { func askForAccountProfile(ctx context.Context) (string, error) {
path, err := databrickscfg.GetPath() path, err := databrickscfg.GetPath()
if err != nil { if err != nil {
return "", fmt.Errorf("cannot determine Databricks config file path: %w", err) return "", fmt.Errorf("cannot determine Databricks config file path: %w", err)
@ -171,7 +229,7 @@ func askForAccountProfile() (string, error) {
case 1: case 1:
return profiles[0].Name, nil return profiles[0].Name, nil
} }
i, _, err := (&promptui.Select{ i, _, err := cmdio.RunSelect(ctx, &promptui.Select{
Label: fmt.Sprintf("Account profiles defined in %s", file), Label: fmt.Sprintf("Account profiles defined in %s", file),
Items: profiles, Items: profiles,
Searcher: profiles.SearchCaseInsensitive, Searcher: profiles.SearchCaseInsensitive,
@ -182,9 +240,7 @@ func askForAccountProfile() (string, error) {
Inactive: `{{.Name}}`, Inactive: `{{.Name}}`,
Selected: `{{ "Using account profile" | faint }}: {{ .Name | bold }}`, Selected: `{{ "Using account profile" | faint }}: {{ .Name | bold }}`,
}, },
Stdin: os.Stdin, })
Stdout: os.Stderr,
}).Run()
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@ -2,9 +2,15 @@ package root
import ( import (
"context" "context"
"os"
"path/filepath"
"testing" "testing"
"time"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/databricks-sdk-go/config"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestEmptyHttpRequest(t *testing.T) { func TestEmptyHttpRequest(t *testing.T) {
@ -12,3 +18,161 @@ func TestEmptyHttpRequest(t *testing.T) {
req := emptyHttpRequest(ctx) req := emptyHttpRequest(ctx)
assert.Equal(t, req.Context(), ctx) assert.Equal(t, req.Context(), ctx)
} }
type promptFn func(ctx context.Context, cfg *config.Config, retry bool) (any, error)
var accountPromptFn = func(ctx context.Context, cfg *config.Config, retry bool) (any, error) {
return accountClientOrPrompt(ctx, cfg, retry)
}
var workspacePromptFn = func(ctx context.Context, cfg *config.Config, retry bool) (any, error) {
return workspaceClientOrPrompt(ctx, cfg, retry)
}
func expectPrompts(t *testing.T, fn promptFn, config *config.Config) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
// Channel to pass errors from the prompting function back to the test.
errch := make(chan error, 1)
ctx, io := cmdio.SetupTest(ctx)
go func() {
defer close(errch)
defer cancel()
_, err := fn(ctx, config, true)
errch <- err
}()
// Expect a prompt
line, _, err := io.Stderr.ReadLine()
if assert.NoError(t, err, "Expected to read a line from stderr") {
assert.Contains(t, string(line), "Search:")
} else {
// If there was an error reading from stderr, the prompting function must have terminated early.
assert.NoError(t, <-errch)
}
}
func expectReturns(t *testing.T, fn promptFn, config *config.Config) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
ctx, _ = cmdio.SetupTest(ctx)
client, err := fn(ctx, config, true)
require.NoError(t, err)
require.NotNil(t, client)
}
func TestAccountClientOrPrompt(t *testing.T) {
dir := t.TempDir()
configFile := filepath.Join(dir, ".databrickscfg")
err := os.WriteFile(
configFile,
[]byte(`
[account-1111]
host = https://accounts.azuredatabricks.net/
account_id = 1111
token = foobar
[account-1112]
host = https://accounts.azuredatabricks.net/
account_id = 1112
token = foobar
`),
0755)
require.NoError(t, err)
t.Setenv("DATABRICKS_CONFIG_FILE", configFile)
t.Setenv("PATH", "/nothing")
t.Run("Prompt if nothing is specified", func(t *testing.T) {
expectPrompts(t, accountPromptFn, &config.Config{})
})
t.Run("Prompt if a workspace host is specified", func(t *testing.T) {
expectPrompts(t, accountPromptFn, &config.Config{
Host: "https://adb-1234567.89.azuredatabricks.net/",
AccountID: "1234",
Token: "foobar",
})
})
t.Run("Prompt if account ID is not specified", func(t *testing.T) {
expectPrompts(t, accountPromptFn, &config.Config{
Host: "https://accounts.azuredatabricks.net/",
Token: "foobar",
})
})
t.Run("Prompt if no credential provider can be configured", func(t *testing.T) {
expectPrompts(t, accountPromptFn, &config.Config{
Host: "https://accounts.azuredatabricks.net/",
AccountID: "1234",
})
})
t.Run("Returns if configuration is valid", func(t *testing.T) {
expectReturns(t, accountPromptFn, &config.Config{
Host: "https://accounts.azuredatabricks.net/",
AccountID: "1234",
Token: "foobar",
})
})
t.Run("Returns if a valid profile is specified", func(t *testing.T) {
expectReturns(t, accountPromptFn, &config.Config{
Profile: "account-1111",
})
})
}
func TestWorkspaceClientOrPrompt(t *testing.T) {
dir := t.TempDir()
configFile := filepath.Join(dir, ".databrickscfg")
err := os.WriteFile(
configFile,
[]byte(`
[workspace-1111]
host = https://adb-1111.11.azuredatabricks.net/
token = foobar
[workspace-1112]
host = https://adb-1112.12.azuredatabricks.net/
token = foobar
`),
0755)
require.NoError(t, err)
t.Setenv("DATABRICKS_CONFIG_FILE", configFile)
t.Setenv("PATH", "/nothing")
t.Run("Prompt if nothing is specified", func(t *testing.T) {
expectPrompts(t, workspacePromptFn, &config.Config{})
})
t.Run("Prompt if an account host is specified", func(t *testing.T) {
expectPrompts(t, workspacePromptFn, &config.Config{
Host: "https://accounts.azuredatabricks.net/",
AccountID: "1234",
Token: "foobar",
})
})
t.Run("Prompt if no credential provider can be configured", func(t *testing.T) {
expectPrompts(t, workspacePromptFn, &config.Config{
Host: "https://adb-1111.11.azuredatabricks.net/",
})
})
t.Run("Returns if configuration is valid", func(t *testing.T) {
expectReturns(t, workspacePromptFn, &config.Config{
Host: "https://adb-1111.11.azuredatabricks.net/",
Token: "foobar",
})
})
t.Run("Returns if a valid profile is specified", func(t *testing.T) {
expectReturns(t, workspacePromptFn, &config.Config{
Profile: "workspace-1111",
})
})
}

View File

@ -205,6 +205,13 @@ func Prompt(ctx context.Context) *promptui.Prompt {
} }
} }
func RunSelect(ctx context.Context, prompt *promptui.Select) (int, string, error) {
c := fromContext(ctx)
prompt.Stdin = io.NopCloser(c.in)
prompt.Stdout = nopWriteCloser{c.err}
return prompt.Run()
}
func (c *cmdIO) simplePrompt(label string) *promptui.Prompt { func (c *cmdIO) simplePrompt(label string) *promptui.Prompt {
return &promptui.Prompt{ return &promptui.Prompt{
Label: label, Label: label,

46
libs/cmdio/testing.go Normal file
View File

@ -0,0 +1,46 @@
package cmdio
import (
"bufio"
"context"
"io"
)
type Test struct {
Done context.CancelFunc
Stdin *bufio.Writer
Stdout *bufio.Reader
Stderr *bufio.Reader
}
func SetupTest(ctx context.Context) (context.Context, *Test) {
rin, win := io.Pipe()
rout, wout := io.Pipe()
rerr, werr := io.Pipe()
cmdio := &cmdIO{
interactive: true,
in: rin,
out: wout,
err: werr,
}
ctx, cancel := context.WithCancel(ctx)
ctx = InContext(ctx, cmdio)
// Wait for context to be done, so we can drain stdin and close the pipes.
go func() {
<-ctx.Done()
rin.Close()
wout.Close()
werr.Close()
}()
return ctx, &Test{
Done: cancel,
Stdin: bufio.NewWriter(win),
Stdout: bufio.NewReader(rout),
Stderr: bufio.NewReader(rerr),
}
}