mirror of https://github.com/databricks/cli.git
Prompt once for a client profile (#727)
## Changes The previous implementation ran the risk of infinite looping for the account client due to a mismatch in determining what constitutes an account client between the CLI and SDK (see [here](83443bae8d/libs/databrickscfg/profiles.go (L61)
) and [here](0fdc5165e5/config/config.go (L160)
)). Ultimately, this code must never infinite loop. If a user is prompted and selects a profile that cannot be used, they should receive that feedback immediately and try again, instead of being prompted again. Related to #726. ## Tests <!-- How is this tested? -->
This commit is contained in:
parent
373f441eb2
commit
0cb05d1ded
146
cmd/root/auth.go
146
cmd/root/auth.go
|
@ -25,13 +25,57 @@ func initProfileFlag(cmd *cobra.Command) {
|
|||
cmd.RegisterFlagCompletionFunc("profile", databrickscfg.ProfileCompletion)
|
||||
}
|
||||
|
||||
func profileFlagValue(cmd *cobra.Command) (string, bool) {
|
||||
profileFlag := cmd.Flag("profile")
|
||||
if profileFlag == nil {
|
||||
return "", false
|
||||
}
|
||||
value := profileFlag.Value.String()
|
||||
return value, value != ""
|
||||
}
|
||||
|
||||
// Helper function to create an account client or prompt once if the given configuration is not valid.
|
||||
func accountClientOrPrompt(ctx context.Context, cfg *config.Config, allowPrompt bool) (*databricks.AccountClient, error) {
|
||||
a, err := databricks.NewAccountClient((*databricks.Config)(cfg))
|
||||
if err == nil {
|
||||
err = a.Config.Authenticate(emptyHttpRequest(ctx))
|
||||
}
|
||||
|
||||
prompt := false
|
||||
if allowPrompt && err != nil && cmdio.IsInteractive(ctx) {
|
||||
// Prompt to select a profile if the current configuration is not an account client.
|
||||
prompt = prompt || errors.Is(err, databricks.ErrNotAccountClient)
|
||||
// Prompt to select a profile if the current configuration doesn't resolve to a credential provider.
|
||||
prompt = prompt || errors.Is(err, config.ErrCannotConfigureAuth)
|
||||
}
|
||||
|
||||
if !prompt {
|
||||
// If we are not prompting, we can return early.
|
||||
return a, err
|
||||
}
|
||||
|
||||
// Try picking a profile dynamically if the current configuration is not valid.
|
||||
profile, err := askForAccountProfile(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
a, err = databricks.NewAccountClient(&databricks.Config{Profile: profile})
|
||||
if err == nil {
|
||||
err = a.Config.Authenticate(emptyHttpRequest(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func MustAccountClient(cmd *cobra.Command, args []string) error {
|
||||
cfg := &config.Config{}
|
||||
|
||||
// command-line flag can specify the profile in use
|
||||
profileFlag := cmd.Flag("profile")
|
||||
if profileFlag != nil {
|
||||
cfg.Profile = profileFlag.Value.String()
|
||||
// The command-line profile flag takes precedence over DATABRICKS_CONFIG_PROFILE.
|
||||
profile, hasProfileFlag := profileFlagValue(cmd)
|
||||
if hasProfileFlag {
|
||||
cfg.Profile = profile
|
||||
}
|
||||
|
||||
if cfg.Profile == "" {
|
||||
|
@ -48,16 +92,8 @@ func MustAccountClient(cmd *cobra.Command, args []string) error {
|
|||
}
|
||||
}
|
||||
|
||||
TRY_AUTH: // or try picking a config profile dynamically
|
||||
a, err := databricks.NewAccountClient((*databricks.Config)(cfg))
|
||||
if cmdio.IsInteractive(cmd.Context()) && errors.Is(err, databricks.ErrNotAccountClient) {
|
||||
profile, err := askForAccountProfile()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfg = &config.Config{Profile: profile}
|
||||
goto TRY_AUTH
|
||||
}
|
||||
allowPrompt := !hasProfileFlag
|
||||
a, err := accountClientOrPrompt(cmd.Context(), cfg, allowPrompt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -66,13 +102,48 @@ TRY_AUTH: // or try picking a config profile dynamically
|
|||
return nil
|
||||
}
|
||||
|
||||
// Helper function to create a workspace client or prompt once if the given configuration is not valid.
|
||||
func workspaceClientOrPrompt(ctx context.Context, cfg *config.Config, allowPrompt bool) (*databricks.WorkspaceClient, error) {
|
||||
w, err := databricks.NewWorkspaceClient((*databricks.Config)(cfg))
|
||||
if err == nil {
|
||||
err = w.Config.Authenticate(emptyHttpRequest(ctx))
|
||||
}
|
||||
|
||||
prompt := false
|
||||
if allowPrompt && err != nil && cmdio.IsInteractive(ctx) {
|
||||
// Prompt to select a profile if the current configuration is not a workspace client.
|
||||
prompt = prompt || errors.Is(err, databricks.ErrNotWorkspaceClient)
|
||||
// Prompt to select a profile if the current configuration doesn't resolve to a credential provider.
|
||||
prompt = prompt || errors.Is(err, config.ErrCannotConfigureAuth)
|
||||
}
|
||||
|
||||
if !prompt {
|
||||
// If we are not prompting, we can return early.
|
||||
return w, err
|
||||
}
|
||||
|
||||
// Try picking a profile dynamically if the current configuration is not valid.
|
||||
profile, err := askForWorkspaceProfile(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
w, err = databricks.NewWorkspaceClient(&databricks.Config{Profile: profile})
|
||||
if err == nil {
|
||||
err = w.Config.Authenticate(emptyHttpRequest(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return w, nil
|
||||
}
|
||||
|
||||
func MustWorkspaceClient(cmd *cobra.Command, args []string) error {
|
||||
cfg := &config.Config{}
|
||||
|
||||
// command-line flag takes precedence over environment variable
|
||||
profileFlag := cmd.Flag("profile")
|
||||
if profileFlag != nil {
|
||||
cfg.Profile = profileFlag.Value.String()
|
||||
// The command-line profile flag takes precedence over DATABRICKS_CONFIG_PROFILE.
|
||||
profile, hasProfileFlag := profileFlagValue(cmd)
|
||||
if hasProfileFlag {
|
||||
cfg.Profile = profile
|
||||
}
|
||||
|
||||
// try configuring a bundle
|
||||
|
@ -87,24 +158,13 @@ func MustWorkspaceClient(cmd *cobra.Command, args []string) error {
|
|||
cfg = currentBundle.WorkspaceClient().Config
|
||||
}
|
||||
|
||||
TRY_AUTH: // or try picking a config profile dynamically
|
||||
allowPrompt := !hasProfileFlag
|
||||
w, err := workspaceClientOrPrompt(cmd.Context(), cfg, allowPrompt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := cmd.Context()
|
||||
w, err := databricks.NewWorkspaceClient((*databricks.Config)(cfg))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = w.Config.Authenticate(emptyHttpRequest(ctx))
|
||||
if cmdio.IsInteractive(ctx) && errors.Is(err, config.ErrCannotConfigureAuth) {
|
||||
profile, err := askForWorkspaceProfile()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfg = &config.Config{Profile: profile}
|
||||
goto TRY_AUTH
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx = context.WithValue(ctx, &workspaceClient, w)
|
||||
cmd.SetContext(ctx)
|
||||
return nil
|
||||
|
@ -121,7 +181,7 @@ func transformLoadError(path string, err error) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func askForWorkspaceProfile() (string, error) {
|
||||
func askForWorkspaceProfile(ctx context.Context) (string, error) {
|
||||
path, err := databrickscfg.GetPath()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cannot determine Databricks config file path: %w", err)
|
||||
|
@ -136,7 +196,7 @@ func askForWorkspaceProfile() (string, error) {
|
|||
case 1:
|
||||
return profiles[0].Name, nil
|
||||
}
|
||||
i, _, err := (&promptui.Select{
|
||||
i, _, err := cmdio.RunSelect(ctx, &promptui.Select{
|
||||
Label: fmt.Sprintf("Workspace profiles defined in %s", file),
|
||||
Items: profiles,
|
||||
Searcher: profiles.SearchCaseInsensitive,
|
||||
|
@ -147,16 +207,14 @@ func askForWorkspaceProfile() (string, error) {
|
|||
Inactive: `{{.Name}}`,
|
||||
Selected: `{{ "Using workspace profile" | faint }}: {{ .Name | bold }}`,
|
||||
},
|
||||
Stdin: os.Stdin,
|
||||
Stdout: os.Stderr,
|
||||
}).Run()
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return profiles[i].Name, nil
|
||||
}
|
||||
|
||||
func askForAccountProfile() (string, error) {
|
||||
func askForAccountProfile(ctx context.Context) (string, error) {
|
||||
path, err := databrickscfg.GetPath()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cannot determine Databricks config file path: %w", err)
|
||||
|
@ -171,7 +229,7 @@ func askForAccountProfile() (string, error) {
|
|||
case 1:
|
||||
return profiles[0].Name, nil
|
||||
}
|
||||
i, _, err := (&promptui.Select{
|
||||
i, _, err := cmdio.RunSelect(ctx, &promptui.Select{
|
||||
Label: fmt.Sprintf("Account profiles defined in %s", file),
|
||||
Items: profiles,
|
||||
Searcher: profiles.SearchCaseInsensitive,
|
||||
|
@ -182,9 +240,7 @@ func askForAccountProfile() (string, error) {
|
|||
Inactive: `{{.Name}}`,
|
||||
Selected: `{{ "Using account profile" | faint }}: {{ .Name | bold }}`,
|
||||
},
|
||||
Stdin: os.Stdin,
|
||||
Stdout: os.Stderr,
|
||||
}).Run()
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
|
@ -2,9 +2,15 @@ package root
|
|||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/databricks/cli/libs/cmdio"
|
||||
"github.com/databricks/databricks-sdk-go/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEmptyHttpRequest(t *testing.T) {
|
||||
|
@ -12,3 +18,161 @@ func TestEmptyHttpRequest(t *testing.T) {
|
|||
req := emptyHttpRequest(ctx)
|
||||
assert.Equal(t, req.Context(), ctx)
|
||||
}
|
||||
|
||||
type promptFn func(ctx context.Context, cfg *config.Config, retry bool) (any, error)
|
||||
|
||||
var accountPromptFn = func(ctx context.Context, cfg *config.Config, retry bool) (any, error) {
|
||||
return accountClientOrPrompt(ctx, cfg, retry)
|
||||
}
|
||||
|
||||
var workspacePromptFn = func(ctx context.Context, cfg *config.Config, retry bool) (any, error) {
|
||||
return workspaceClientOrPrompt(ctx, cfg, retry)
|
||||
}
|
||||
|
||||
func expectPrompts(t *testing.T, fn promptFn, config *config.Config) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Channel to pass errors from the prompting function back to the test.
|
||||
errch := make(chan error, 1)
|
||||
|
||||
ctx, io := cmdio.SetupTest(ctx)
|
||||
go func() {
|
||||
defer close(errch)
|
||||
defer cancel()
|
||||
_, err := fn(ctx, config, true)
|
||||
errch <- err
|
||||
}()
|
||||
|
||||
// Expect a prompt
|
||||
line, _, err := io.Stderr.ReadLine()
|
||||
if assert.NoError(t, err, "Expected to read a line from stderr") {
|
||||
assert.Contains(t, string(line), "Search:")
|
||||
} else {
|
||||
// If there was an error reading from stderr, the prompting function must have terminated early.
|
||||
assert.NoError(t, <-errch)
|
||||
}
|
||||
}
|
||||
|
||||
func expectReturns(t *testing.T, fn promptFn, config *config.Config) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
ctx, _ = cmdio.SetupTest(ctx)
|
||||
client, err := fn(ctx, config, true)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
}
|
||||
|
||||
func TestAccountClientOrPrompt(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configFile := filepath.Join(dir, ".databrickscfg")
|
||||
err := os.WriteFile(
|
||||
configFile,
|
||||
[]byte(`
|
||||
[account-1111]
|
||||
host = https://accounts.azuredatabricks.net/
|
||||
account_id = 1111
|
||||
token = foobar
|
||||
|
||||
[account-1112]
|
||||
host = https://accounts.azuredatabricks.net/
|
||||
account_id = 1112
|
||||
token = foobar
|
||||
`),
|
||||
0755)
|
||||
require.NoError(t, err)
|
||||
t.Setenv("DATABRICKS_CONFIG_FILE", configFile)
|
||||
t.Setenv("PATH", "/nothing")
|
||||
|
||||
t.Run("Prompt if nothing is specified", func(t *testing.T) {
|
||||
expectPrompts(t, accountPromptFn, &config.Config{})
|
||||
})
|
||||
|
||||
t.Run("Prompt if a workspace host is specified", func(t *testing.T) {
|
||||
expectPrompts(t, accountPromptFn, &config.Config{
|
||||
Host: "https://adb-1234567.89.azuredatabricks.net/",
|
||||
AccountID: "1234",
|
||||
Token: "foobar",
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Prompt if account ID is not specified", func(t *testing.T) {
|
||||
expectPrompts(t, accountPromptFn, &config.Config{
|
||||
Host: "https://accounts.azuredatabricks.net/",
|
||||
Token: "foobar",
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Prompt if no credential provider can be configured", func(t *testing.T) {
|
||||
expectPrompts(t, accountPromptFn, &config.Config{
|
||||
Host: "https://accounts.azuredatabricks.net/",
|
||||
AccountID: "1234",
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Returns if configuration is valid", func(t *testing.T) {
|
||||
expectReturns(t, accountPromptFn, &config.Config{
|
||||
Host: "https://accounts.azuredatabricks.net/",
|
||||
AccountID: "1234",
|
||||
Token: "foobar",
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Returns if a valid profile is specified", func(t *testing.T) {
|
||||
expectReturns(t, accountPromptFn, &config.Config{
|
||||
Profile: "account-1111",
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkspaceClientOrPrompt(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configFile := filepath.Join(dir, ".databrickscfg")
|
||||
err := os.WriteFile(
|
||||
configFile,
|
||||
[]byte(`
|
||||
[workspace-1111]
|
||||
host = https://adb-1111.11.azuredatabricks.net/
|
||||
token = foobar
|
||||
|
||||
[workspace-1112]
|
||||
host = https://adb-1112.12.azuredatabricks.net/
|
||||
token = foobar
|
||||
`),
|
||||
0755)
|
||||
require.NoError(t, err)
|
||||
t.Setenv("DATABRICKS_CONFIG_FILE", configFile)
|
||||
t.Setenv("PATH", "/nothing")
|
||||
|
||||
t.Run("Prompt if nothing is specified", func(t *testing.T) {
|
||||
expectPrompts(t, workspacePromptFn, &config.Config{})
|
||||
})
|
||||
|
||||
t.Run("Prompt if an account host is specified", func(t *testing.T) {
|
||||
expectPrompts(t, workspacePromptFn, &config.Config{
|
||||
Host: "https://accounts.azuredatabricks.net/",
|
||||
AccountID: "1234",
|
||||
Token: "foobar",
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Prompt if no credential provider can be configured", func(t *testing.T) {
|
||||
expectPrompts(t, workspacePromptFn, &config.Config{
|
||||
Host: "https://adb-1111.11.azuredatabricks.net/",
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Returns if configuration is valid", func(t *testing.T) {
|
||||
expectReturns(t, workspacePromptFn, &config.Config{
|
||||
Host: "https://adb-1111.11.azuredatabricks.net/",
|
||||
Token: "foobar",
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Returns if a valid profile is specified", func(t *testing.T) {
|
||||
expectReturns(t, workspacePromptFn, &config.Config{
|
||||
Profile: "workspace-1111",
|
||||
})
|
||||
})
|
||||
}
|
||||
|
|
|
@ -205,6 +205,13 @@ func Prompt(ctx context.Context) *promptui.Prompt {
|
|||
}
|
||||
}
|
||||
|
||||
func RunSelect(ctx context.Context, prompt *promptui.Select) (int, string, error) {
|
||||
c := fromContext(ctx)
|
||||
prompt.Stdin = io.NopCloser(c.in)
|
||||
prompt.Stdout = nopWriteCloser{c.err}
|
||||
return prompt.Run()
|
||||
}
|
||||
|
||||
func (c *cmdIO) simplePrompt(label string) *promptui.Prompt {
|
||||
return &promptui.Prompt{
|
||||
Label: label,
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
package cmdio
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"io"
|
||||
)
|
||||
|
||||
type Test struct {
|
||||
Done context.CancelFunc
|
||||
|
||||
Stdin *bufio.Writer
|
||||
Stdout *bufio.Reader
|
||||
Stderr *bufio.Reader
|
||||
}
|
||||
|
||||
func SetupTest(ctx context.Context) (context.Context, *Test) {
|
||||
rin, win := io.Pipe()
|
||||
rout, wout := io.Pipe()
|
||||
rerr, werr := io.Pipe()
|
||||
|
||||
cmdio := &cmdIO{
|
||||
interactive: true,
|
||||
in: rin,
|
||||
out: wout,
|
||||
err: werr,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
ctx = InContext(ctx, cmdio)
|
||||
|
||||
// Wait for context to be done, so we can drain stdin and close the pipes.
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
rin.Close()
|
||||
wout.Close()
|
||||
werr.Close()
|
||||
}()
|
||||
|
||||
return ctx, &Test{
|
||||
Done: cancel,
|
||||
Stdin: bufio.NewWriter(win),
|
||||
Stdout: bufio.NewReader(rout),
|
||||
Stderr: bufio.NewReader(rerr),
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue