Improve token refresh flow (#1434)

## Changes
Currently, there are a number of issues with the non-happy-path flows
for token refresh in the CLI.

If the token refresh fails, the raw error message is presented to the
user, as seen below. This message is very difficult for users to
interpret and doesn't give any clear direction on how to resolve this
issue.
```
Error: token refresh: Post "https://adb-<WSID>.azuredatabricks.net/oidc/v1/token": http 400: {"error":"invalid_request","error_description":"Refresh token is invalid"}
```

When logging in again, I've noticed that the timeout for logging in is
very short, only 45 seconds. If a user is using a password manager and
needs to login to that first, or needs to do MFA, 45 seconds may not be
enough time. to an account-level profile, it is quite frustrating for
users to need to re-enter account ID information when that information
is already stored in the user's `.databrickscfg` file.

This PR tackles these two issues. First, the presentation of error
messages from `databricks auth token` is improved substantially by
converting the `error` into a human-readable message. When the refresh
token is invalid, it will present a command for the user to run to
reauthenticate. If the token fetching failed for some other reason, that
reason will be presented in a nice way, providing front-line debugging
steps and ultimately redirecting users to file a ticket at this repo if
they can't resolve the issue themselves. After this PR, the new error
message is:
```
Error: a new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run `.databricks/databricks auth login --host https://adb-<WSID>.azuredatabricks.net`
```

To improve the login flow, this PR modifies `databricks auth login` to
auto-complete the account ID from the profile when present.
Additionally, it increases the login timeout from 45 seconds to 1 hour
to give the user sufficient time to login as needed.

To test this change, I needed to refactor some components of the CLI
around profile management, the token cache, and the API client used to
fetch OAuth tokens. These are now settable in the context, and a
demonstration of how they can be set and used is found in
`auth_test.go`.

Separately, this also demonstrates a sort-of integration test of the CLI
by executing the Cobra command for `databricks auth token` from tests,
which may be useful for testing other end-to-end functionality in the
CLI. In particular, I believe this is necessary in order to set flag
values (like the `--profile` flag in this case) for use in testing.

## Tests
Unit tests cover the unhappy and happy paths using the mocked API
client, token cache, and profiler.

Manually tested

---------

Co-authored-by: Pieter Noordhuis <pieter.noordhuis@databricks.com>
This commit is contained in:
Miles Yucht 2024-05-16 12:22:09 +02:00 committed by GitHub
parent 157877a152
commit f7d4b272f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
28 changed files with 743 additions and 323 deletions

View File

@ -1,6 +1,8 @@
package config_tests
import (
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/assert"
@ -9,12 +11,14 @@ import (
func TestGitAutoLoadWithEnvironment(t *testing.T) {
b := load(t, "./environments_autoload_git")
assert.True(t, b.Config.Bundle.Git.Inferred)
assert.Contains(t, b.Config.Bundle.Git.OriginURL, "/cli")
validUrl := strings.Contains(b.Config.Bundle.Git.OriginURL, "/cli") || strings.Contains(b.Config.Bundle.Git.OriginURL, "/bricks")
assert.True(t, validUrl, fmt.Sprintf("Expected URL to contain '/cli' or '/bricks', got %s", b.Config.Bundle.Git.OriginURL))
}
func TestGitManuallySetBranchWithEnvironment(t *testing.T) {
b := loadTarget(t, "./environments_autoload_git", "production")
assert.False(t, b.Config.Bundle.Git.Inferred)
assert.Equal(t, "main", b.Config.Bundle.Git.Branch)
assert.Contains(t, b.Config.Bundle.Git.OriginURL, "/cli")
validUrl := strings.Contains(b.Config.Bundle.Git.OriginURL, "/cli") || strings.Contains(b.Config.Bundle.Git.OriginURL, "/bricks")
assert.True(t, validUrl, fmt.Sprintf("Expected URL to contain '/cli' or '/bricks', got %s", b.Config.Bundle.Git.OriginURL))
}

View File

@ -2,6 +2,8 @@ package config_tests
import (
"context"
"fmt"
"strings"
"testing"
"github.com/databricks/cli/bundle"
@ -13,14 +15,16 @@ import (
func TestGitAutoLoad(t *testing.T) {
b := load(t, "./autoload_git")
assert.True(t, b.Config.Bundle.Git.Inferred)
assert.Contains(t, b.Config.Bundle.Git.OriginURL, "/cli")
validUrl := strings.Contains(b.Config.Bundle.Git.OriginURL, "/cli") || strings.Contains(b.Config.Bundle.Git.OriginURL, "/bricks")
assert.True(t, validUrl, fmt.Sprintf("Expected URL to contain '/cli' or '/bricks', got %s", b.Config.Bundle.Git.OriginURL))
}
func TestGitManuallySetBranch(t *testing.T) {
b := loadTarget(t, "./autoload_git", "production")
assert.False(t, b.Config.Bundle.Git.Inferred)
assert.Equal(t, "main", b.Config.Bundle.Git.Branch)
assert.Contains(t, b.Config.Bundle.Git.OriginURL, "/cli")
validUrl := strings.Contains(b.Config.Bundle.Git.OriginURL, "/cli") || strings.Contains(b.Config.Bundle.Git.OriginURL, "/bricks")
assert.True(t, validUrl, fmt.Sprintf("Expected URL to contain '/cli' or '/bricks', got %s", b.Config.Bundle.Git.OriginURL))
}
func TestGitBundleBranchValidation(t *testing.T) {

View File

@ -10,7 +10,7 @@ import (
"net/url"
"strings"
"github.com/databricks/cli/libs/databrickscfg"
"github.com/databricks/cli/libs/databrickscfg/profile"
"github.com/databricks/databricks-sdk-go/config"
"github.com/spf13/cobra"
"gopkg.in/ini.v1"
@ -70,7 +70,7 @@ func resolveSection(cfg *config.Config, iniFile *config.File) (*ini.Section, err
}
func loadFromDatabricksCfg(ctx context.Context, cfg *config.Config) error {
iniFile, err := databrickscfg.Get(ctx)
iniFile, err := profile.DefaultProfiler.Get(ctx)
if errors.Is(err, fs.ErrNotExist) {
// it's fine not to have ~/.databrickscfg
return nil

View File

@ -11,6 +11,7 @@ import (
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/databrickscfg"
"github.com/databricks/cli/libs/databrickscfg/cfgpickers"
"github.com/databricks/cli/libs/databrickscfg/profile"
"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/config"
"github.com/spf13/cobra"
@ -31,6 +32,7 @@ func configureHost(ctx context.Context, persistentAuth *auth.PersistentAuth, arg
}
const minimalDbConnectVersion = "13.1"
const defaultTimeout = 1 * time.Hour
func newLoginCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
defaultConfigPath := "~/.databrickscfg"
@ -84,7 +86,7 @@ depends on the existing profiles you have set in your configuration file
var loginTimeout time.Duration
var configureCluster bool
cmd.Flags().DurationVar(&loginTimeout, "timeout", auth.DefaultTimeout,
cmd.Flags().DurationVar(&loginTimeout, "timeout", defaultTimeout,
"Timeout for completing login challenge in the browser")
cmd.Flags().BoolVar(&configureCluster, "configure-cluster", false,
"Prompts to configure cluster")
@ -108,7 +110,7 @@ depends on the existing profiles you have set in your configuration file
profileName = profile
}
err := setHost(ctx, profileName, persistentAuth, args)
err := setHostAndAccountId(ctx, profileName, persistentAuth, args)
if err != nil {
return err
}
@ -117,17 +119,10 @@ depends on the existing profiles you have set in your configuration file
// We need the config without the profile before it's used to initialise new workspace client below.
// Otherwise it will complain about non existing profile because it was not yet saved.
cfg := config.Config{
Host: persistentAuth.Host,
AuthType: "databricks-cli",
Host: persistentAuth.Host,
AccountID: persistentAuth.AccountID,
AuthType: "databricks-cli",
}
if cfg.IsAccountClient() && persistentAuth.AccountID == "" {
accountId, err := promptForAccountID(ctx)
if err != nil {
return err
}
persistentAuth.AccountID = accountId
}
cfg.AccountID = persistentAuth.AccountID
ctx, cancel := context.WithTimeout(ctx, loginTimeout)
defer cancel()
@ -172,15 +167,15 @@ depends on the existing profiles you have set in your configuration file
return cmd
}
func setHost(ctx context.Context, profileName string, persistentAuth *auth.PersistentAuth, args []string) error {
func setHostAndAccountId(ctx context.Context, profileName string, persistentAuth *auth.PersistentAuth, args []string) error {
profiler := profile.GetProfiler(ctx)
// If the chosen profile has a hostname and the user hasn't specified a host, infer the host from the profile.
_, profiles, err := databrickscfg.LoadProfiles(ctx, func(p databrickscfg.Profile) bool {
return p.Name == profileName
})
profiles, err := profiler.LoadProfiles(ctx, profile.WithName(profileName))
// Tolerate ErrNoConfiguration here, as we will write out a configuration as part of the login flow.
if err != nil && !errors.Is(err, databrickscfg.ErrNoConfiguration) {
if err != nil && !errors.Is(err, profile.ErrNoConfiguration) {
return err
}
if persistentAuth.Host == "" {
if len(profiles) > 0 && profiles[0].Host != "" {
persistentAuth.Host = profiles[0].Host
@ -188,5 +183,17 @@ func setHost(ctx context.Context, profileName string, persistentAuth *auth.Persi
configureHost(ctx, persistentAuth, args, 0)
}
}
isAccountClient := (&config.Config{Host: persistentAuth.Host}).IsAccountClient()
if isAccountClient && persistentAuth.AccountID == "" {
if len(profiles) > 0 && profiles[0].AccountID != "" {
persistentAuth.AccountID = profiles[0].AccountID
} else {
accountId, err := promptForAccountID(ctx)
if err != nil {
return err
}
persistentAuth.AccountID = accountId
}
}
return nil
}

View File

@ -12,6 +12,6 @@ import (
func TestSetHostDoesNotFailWithNoDatabrickscfg(t *testing.T) {
ctx := context.Background()
ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "./imaginary-file/databrickscfg")
err := setHost(ctx, "foo", &auth.PersistentAuth{Host: "test"}, []string{})
err := setHostAndAccountId(ctx, "foo", &auth.PersistentAuth{Host: "test"}, []string{})
assert.NoError(t, err)
}

View File

@ -8,7 +8,7 @@ import (
"time"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/databrickscfg"
"github.com/databricks/cli/libs/databrickscfg/profile"
"github.com/databricks/cli/libs/log"
"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/config"
@ -94,7 +94,7 @@ func newProfilesCommand() *cobra.Command {
cmd.RunE = func(cmd *cobra.Command, args []string) error {
var profiles []*profileMetadata
iniFile, err := databrickscfg.Get(cmd.Context())
iniFile, err := profile.DefaultProfiler.Get(cmd.Context())
if os.IsNotExist(err) {
// return empty list for non-configured machines
iniFile = &config.File{

View File

@ -4,12 +4,44 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"strings"
"time"
"github.com/databricks/cli/libs/auth"
"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/spf13/cobra"
)
type tokenErrorResponse struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
}
func buildLoginCommand(profile string, persistentAuth *auth.PersistentAuth) string {
executable := os.Args[0]
cmd := []string{
executable,
"auth",
"login",
}
if profile != "" {
cmd = append(cmd, "--profile", profile)
} else {
cmd = append(cmd, "--host", persistentAuth.Host)
if persistentAuth.AccountID != "" {
cmd = append(cmd, "--account-id", persistentAuth.AccountID)
}
}
return strings.Join(cmd, " ")
}
func helpfulError(profile string, persistentAuth *auth.PersistentAuth) string {
loginMsg := buildLoginCommand(profile, persistentAuth)
return fmt.Sprintf("Try logging in again with `%s` before retrying. If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new", loginMsg)
}
func newTokenCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
cmd := &cobra.Command{
Use: "token [HOST]",
@ -17,7 +49,7 @@ func newTokenCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
}
var tokenTimeout time.Duration
cmd.Flags().DurationVar(&tokenTimeout, "timeout", auth.DefaultTimeout,
cmd.Flags().DurationVar(&tokenTimeout, "timeout", defaultTimeout,
"Timeout for acquiring a token.")
cmd.RunE = func(cmd *cobra.Command, args []string) error {
@ -29,11 +61,11 @@ func newTokenCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
profileName = profileFlag.Value.String()
// If a profile is provided we read the host from the .databrickscfg file
if profileName != "" && len(args) > 0 {
return errors.New("providing both a profile and a host parameters is not supported")
return errors.New("providing both a profile and host is not supported")
}
}
err := setHost(ctx, profileName, persistentAuth, args)
err := setHostAndAccountId(ctx, profileName, persistentAuth, args)
if err != nil {
return err
}
@ -42,8 +74,21 @@ func newTokenCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
ctx, cancel := context.WithTimeout(ctx, tokenTimeout)
defer cancel()
t, err := persistentAuth.Load(ctx)
if err != nil {
return err
var httpErr *httpclient.HttpError
if errors.As(err, &httpErr) {
helpMsg := helpfulError(profileName, persistentAuth)
t := &tokenErrorResponse{}
err = json.Unmarshal([]byte(httpErr.Message), t)
if err != nil {
return fmt.Errorf("unexpected parsing token response: %w. %s", err, helpMsg)
}
if t.ErrorDescription == "Refresh token is invalid" {
return fmt.Errorf("a new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run `%s`", buildLoginCommand(profileName, persistentAuth))
} else {
return fmt.Errorf("unexpected error refreshing token: %s. %s", t.ErrorDescription, helpMsg)
}
} else if err != nil {
return fmt.Errorf("unexpected error refreshing token: %w. %s", err, helpfulError(profileName, persistentAuth))
}
raw, err := json.MarshalIndent(t, "", " ")
if err != nil {

168
cmd/auth/token_test.go Normal file
View File

@ -0,0 +1,168 @@
package auth_test
import (
"bytes"
"context"
"encoding/json"
"testing"
"time"
"github.com/databricks/cli/cmd"
"github.com/databricks/cli/libs/auth"
"github.com/databricks/cli/libs/auth/cache"
"github.com/databricks/cli/libs/databrickscfg/profile"
"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/httpclient/fixtures"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"golang.org/x/oauth2"
)
var refreshFailureTokenResponse = fixtures.HTTPFixture{
MatchAny: true,
Status: 401,
Response: map[string]string{
"error": "invalid_request",
"error_description": "Refresh token is invalid",
},
}
var refreshFailureInvalidResponse = fixtures.HTTPFixture{
MatchAny: true,
Status: 401,
Response: "Not json",
}
var refreshFailureOtherError = fixtures.HTTPFixture{
MatchAny: true,
Status: 401,
Response: map[string]string{
"error": "other_error",
"error_description": "Databricks is down",
},
}
var refreshSuccessTokenResponse = fixtures.HTTPFixture{
MatchAny: true,
Status: 200,
Response: map[string]string{
"access_token": "new-access-token",
"token_type": "Bearer",
"expires_in": "3600",
},
}
func validateToken(t *testing.T, resp string) {
res := map[string]string{}
err := json.Unmarshal([]byte(resp), &res)
assert.NoError(t, err)
assert.Equal(t, "new-access-token", res["access_token"])
assert.Equal(t, "Bearer", res["token_type"])
}
func getContextForTest(f fixtures.HTTPFixture) context.Context {
profiler := profile.InMemoryProfiler{
Profiles: profile.Profiles{
{
Name: "expired",
Host: "https://accounts.cloud.databricks.com",
AccountID: "expired",
},
{
Name: "active",
Host: "https://accounts.cloud.databricks.com",
AccountID: "active",
},
},
}
tokenCache := &cache.InMemoryTokenCache{
Tokens: map[string]*oauth2.Token{
"https://accounts.cloud.databricks.com/oidc/accounts/expired": {
RefreshToken: "expired",
},
"https://accounts.cloud.databricks.com/oidc/accounts/active": {
RefreshToken: "active",
Expiry: time.Now().Add(1 * time.Hour), // Hopefully unit tests don't take an hour to run
},
},
}
client := httpclient.NewApiClient(httpclient.ClientConfig{
Transport: fixtures.SliceTransport{f},
})
ctx := profile.WithProfiler(context.Background(), profiler)
ctx = cache.WithTokenCache(ctx, tokenCache)
ctx = auth.WithApiClientForOAuth(ctx, client)
return ctx
}
func getCobraCmdForTest(f fixtures.HTTPFixture) (*cobra.Command, *bytes.Buffer) {
ctx := getContextForTest(f)
c := cmd.New(ctx)
output := &bytes.Buffer{}
c.SetOut(output)
return c, output
}
func TestTokenCmdWithProfilePrintsHelpfulLoginMessageOnRefreshFailure(t *testing.T) {
cmd, output := getCobraCmdForTest(refreshFailureTokenResponse)
cmd.SetArgs([]string{"auth", "token", "--profile", "expired"})
err := cmd.Execute()
out := output.String()
assert.Empty(t, out)
assert.ErrorContains(t, err, "a new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run ")
assert.ErrorContains(t, err, "auth login --profile expired")
}
func TestTokenCmdWithHostPrintsHelpfulLoginMessageOnRefreshFailure(t *testing.T) {
cmd, output := getCobraCmdForTest(refreshFailureTokenResponse)
cmd.SetArgs([]string{"auth", "token", "--host", "https://accounts.cloud.databricks.com", "--account-id", "expired"})
err := cmd.Execute()
out := output.String()
assert.Empty(t, out)
assert.ErrorContains(t, err, "a new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run ")
assert.ErrorContains(t, err, "auth login --host https://accounts.cloud.databricks.com --account-id expired")
}
func TestTokenCmdInvalidResponse(t *testing.T) {
cmd, output := getCobraCmdForTest(refreshFailureInvalidResponse)
cmd.SetArgs([]string{"auth", "token", "--profile", "active"})
err := cmd.Execute()
out := output.String()
assert.Empty(t, out)
assert.ErrorContains(t, err, "unexpected parsing token response: invalid character 'N' looking for beginning of value. Try logging in again with ")
assert.ErrorContains(t, err, "auth login --profile active` before retrying. If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new")
}
func TestTokenCmdOtherErrorResponse(t *testing.T) {
cmd, output := getCobraCmdForTest(refreshFailureOtherError)
cmd.SetArgs([]string{"auth", "token", "--profile", "active"})
err := cmd.Execute()
out := output.String()
assert.Empty(t, out)
assert.ErrorContains(t, err, "unexpected error refreshing token: Databricks is down. Try logging in again with ")
assert.ErrorContains(t, err, "auth login --profile active` before retrying. If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new")
}
func TestTokenCmdWithProfileSuccess(t *testing.T) {
cmd, output := getCobraCmdForTest(refreshSuccessTokenResponse)
cmd.SetArgs([]string{"auth", "token", "--profile", "active"})
err := cmd.Execute()
out := output.String()
validateToken(t, out)
assert.NoError(t, err)
}
func TestTokenCmdWithHostSuccess(t *testing.T) {
cmd, output := getCobraCmdForTest(refreshSuccessTokenResponse)
cmd.SetArgs([]string{"auth", "token", "--host", "https://accounts.cloud.databricks.com", "--account-id", "expired"})
err := cmd.Execute()
out := output.String()
validateToken(t, out)
assert.NoError(t, err)
}

View File

@ -11,8 +11,8 @@ import (
"github.com/databricks/cli/cmd/labs/github"
"github.com/databricks/cli/cmd/labs/unpack"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/databrickscfg"
"github.com/databricks/cli/libs/databrickscfg/cfgpickers"
"github.com/databricks/cli/libs/databrickscfg/profile"
"github.com/databricks/cli/libs/log"
"github.com/databricks/cli/libs/process"
"github.com/databricks/cli/libs/python"
@ -89,7 +89,7 @@ func (i *installer) Install(ctx context.Context) error {
return err
}
w, err := i.login(ctx)
if err != nil && errors.Is(err, databrickscfg.ErrNoConfiguration) {
if err != nil && errors.Is(err, profile.ErrNoConfiguration) {
cfg, err := i.Installer.envAwareConfig(ctx)
if err != nil {
return err

View File

@ -7,7 +7,7 @@ import (
"net/http"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/databrickscfg"
"github.com/databricks/cli/libs/databrickscfg/profile"
"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/config"
"github.com/manifoldco/promptui"
@ -37,7 +37,7 @@ func (e ErrNoAccountProfiles) Error() string {
func initProfileFlag(cmd *cobra.Command) {
cmd.PersistentFlags().StringP("profile", "p", "", "~/.databrickscfg profile")
cmd.RegisterFlagCompletionFunc("profile", databrickscfg.ProfileCompletion)
cmd.RegisterFlagCompletionFunc("profile", profile.ProfileCompletion)
}
func profileFlagValue(cmd *cobra.Command) (string, bool) {
@ -111,27 +111,29 @@ func MustAccountClient(cmd *cobra.Command, args []string) error {
cfg := &config.Config{}
// The command-line profile flag takes precedence over DATABRICKS_CONFIG_PROFILE.
profile, hasProfileFlag := profileFlagValue(cmd)
pr, hasProfileFlag := profileFlagValue(cmd)
if hasProfileFlag {
cfg.Profile = profile
cfg.Profile = pr
}
ctx := cmd.Context()
ctx = context.WithValue(ctx, &configUsed, cfg)
cmd.SetContext(ctx)
profiler := profile.GetProfiler(ctx)
if cfg.Profile == "" {
// account-level CLI was not really done before, so here are the assumptions:
// 1. only admins will have account configured
// 2. 99% of admins will have access to just one account
// hence, we don't need to create a special "DEFAULT_ACCOUNT" profile yet
_, profiles, err := databrickscfg.LoadProfiles(cmd.Context(), databrickscfg.MatchAccountProfiles)
profiles, err := profiler.LoadProfiles(cmd.Context(), profile.MatchAccountProfiles)
if err == nil && len(profiles) == 1 {
cfg.Profile = profiles[0].Name
}
// if there is no config file, we don't want to fail and instead just skip it
if err != nil && !errors.Is(err, databrickscfg.ErrNoConfiguration) {
if err != nil && !errors.Is(err, profile.ErrNoConfiguration) {
return err
}
}
@ -233,11 +235,12 @@ func SetAccountClient(ctx context.Context, a *databricks.AccountClient) context.
}
func AskForWorkspaceProfile(ctx context.Context) (string, error) {
path, err := databrickscfg.GetPath(ctx)
profiler := profile.GetProfiler(ctx)
path, err := profiler.GetPath(ctx)
if err != nil {
return "", fmt.Errorf("cannot determine Databricks config file path: %w", err)
}
file, profiles, err := databrickscfg.LoadProfiles(ctx, databrickscfg.MatchWorkspaceProfiles)
profiles, err := profiler.LoadProfiles(ctx, profile.MatchWorkspaceProfiles)
if err != nil {
return "", err
}
@ -248,7 +251,7 @@ func AskForWorkspaceProfile(ctx context.Context) (string, error) {
return profiles[0].Name, nil
}
i, _, err := cmdio.RunSelect(ctx, &promptui.Select{
Label: fmt.Sprintf("Workspace profiles defined in %s", file),
Label: fmt.Sprintf("Workspace profiles defined in %s", path),
Items: profiles,
Searcher: profiles.SearchCaseInsensitive,
StartInSearchMode: true,
@ -266,11 +269,12 @@ func AskForWorkspaceProfile(ctx context.Context) (string, error) {
}
func AskForAccountProfile(ctx context.Context) (string, error) {
path, err := databrickscfg.GetPath(ctx)
profiler := profile.GetProfiler(ctx)
path, err := profiler.GetPath(ctx)
if err != nil {
return "", fmt.Errorf("cannot determine Databricks config file path: %w", err)
}
file, profiles, err := databrickscfg.LoadProfiles(ctx, databrickscfg.MatchAccountProfiles)
profiles, err := profiler.LoadProfiles(ctx, profile.MatchAccountProfiles)
if err != nil {
return "", err
}
@ -281,7 +285,7 @@ func AskForAccountProfile(ctx context.Context) (string, error) {
return profiles[0].Name, nil
}
i, _, err := cmdio.RunSelect(ctx, &promptui.Select{
Label: fmt.Sprintf("Account profiles defined in %s", file),
Label: fmt.Sprintf("Account profiles defined in %s", path),
Items: profiles,
Searcher: profiles.SearchCaseInsensitive,
StartInSearchMode: true,

View File

@ -1,106 +1,26 @@
package cache
import (
"encoding/json"
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"context"
"golang.org/x/oauth2"
)
const (
// where the token cache is stored
tokenCacheFile = ".databricks/token-cache.json"
// only the owner of the file has full execute, read, and write access
ownerExecReadWrite = 0o700
// only the owner of the file has full read and write access
ownerReadWrite = 0o600
// format versioning leaves some room for format improvement
tokenCacheVersion = 1
)
var ErrNotConfigured = errors.New("databricks OAuth is not configured for this host")
// this implementation requires the calling code to do a machine-wide lock,
// otherwise the file might get corrupt.
type TokenCache struct {
Version int `json:"version"`
Tokens map[string]*oauth2.Token `json:"tokens"`
fileLocation string
type TokenCache interface {
Store(key string, t *oauth2.Token) error
Lookup(key string) (*oauth2.Token, error)
}
func (c *TokenCache) Store(key string, t *oauth2.Token) error {
err := c.load()
if errors.Is(err, fs.ErrNotExist) {
dir := filepath.Dir(c.fileLocation)
err = os.MkdirAll(dir, ownerExecReadWrite)
if err != nil {
return fmt.Errorf("mkdir: %w", err)
}
} else if err != nil {
return fmt.Errorf("load: %w", err)
}
c.Version = tokenCacheVersion
if c.Tokens == nil {
c.Tokens = map[string]*oauth2.Token{}
}
c.Tokens[key] = t
raw, err := json.MarshalIndent(c, "", " ")
if err != nil {
return fmt.Errorf("marshal: %w", err)
}
return os.WriteFile(c.fileLocation, raw, ownerReadWrite)
var tokenCache int
func WithTokenCache(ctx context.Context, c TokenCache) context.Context {
return context.WithValue(ctx, &tokenCache, c)
}
func (c *TokenCache) Lookup(key string) (*oauth2.Token, error) {
err := c.load()
if errors.Is(err, fs.ErrNotExist) {
return nil, ErrNotConfigured
} else if err != nil {
return nil, fmt.Errorf("load: %w", err)
}
t, ok := c.Tokens[key]
func GetTokenCache(ctx context.Context) TokenCache {
c, ok := ctx.Value(&tokenCache).(TokenCache)
if !ok {
return nil, ErrNotConfigured
return &FileTokenCache{}
}
return t, nil
}
func (c *TokenCache) location() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("home: %w", err)
}
return filepath.Join(home, tokenCacheFile), nil
}
func (c *TokenCache) load() error {
loc, err := c.location()
if err != nil {
return err
}
c.fileLocation = loc
raw, err := os.ReadFile(loc)
if err != nil {
return fmt.Errorf("read: %w", err)
}
err = json.Unmarshal(raw, c)
if err != nil {
return fmt.Errorf("parse: %w", err)
}
if c.Version != tokenCacheVersion {
// in the later iterations we could do state upgraders,
// so that we transform token cache from v1 to v2 without
// losing the tokens and asking the user to re-authenticate.
return fmt.Errorf("needs version %d, got version %d",
tokenCacheVersion, c.Version)
}
return nil
return c
}

108
libs/auth/cache/file.go vendored Normal file
View File

@ -0,0 +1,108 @@
package cache
import (
"encoding/json"
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"golang.org/x/oauth2"
)
const (
// where the token cache is stored
tokenCacheFile = ".databricks/token-cache.json"
// only the owner of the file has full execute, read, and write access
ownerExecReadWrite = 0o700
// only the owner of the file has full read and write access
ownerReadWrite = 0o600
// format versioning leaves some room for format improvement
tokenCacheVersion = 1
)
var ErrNotConfigured = errors.New("databricks OAuth is not configured for this host")
// this implementation requires the calling code to do a machine-wide lock,
// otherwise the file might get corrupt.
type FileTokenCache struct {
Version int `json:"version"`
Tokens map[string]*oauth2.Token `json:"tokens"`
fileLocation string
}
func (c *FileTokenCache) Store(key string, t *oauth2.Token) error {
err := c.load()
if errors.Is(err, fs.ErrNotExist) {
dir := filepath.Dir(c.fileLocation)
err = os.MkdirAll(dir, ownerExecReadWrite)
if err != nil {
return fmt.Errorf("mkdir: %w", err)
}
} else if err != nil {
return fmt.Errorf("load: %w", err)
}
c.Version = tokenCacheVersion
if c.Tokens == nil {
c.Tokens = map[string]*oauth2.Token{}
}
c.Tokens[key] = t
raw, err := json.MarshalIndent(c, "", " ")
if err != nil {
return fmt.Errorf("marshal: %w", err)
}
return os.WriteFile(c.fileLocation, raw, ownerReadWrite)
}
func (c *FileTokenCache) Lookup(key string) (*oauth2.Token, error) {
err := c.load()
if errors.Is(err, fs.ErrNotExist) {
return nil, ErrNotConfigured
} else if err != nil {
return nil, fmt.Errorf("load: %w", err)
}
t, ok := c.Tokens[key]
if !ok {
return nil, ErrNotConfigured
}
return t, nil
}
func (c *FileTokenCache) location() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("home: %w", err)
}
return filepath.Join(home, tokenCacheFile), nil
}
func (c *FileTokenCache) load() error {
loc, err := c.location()
if err != nil {
return err
}
c.fileLocation = loc
raw, err := os.ReadFile(loc)
if err != nil {
return fmt.Errorf("read: %w", err)
}
err = json.Unmarshal(raw, c)
if err != nil {
return fmt.Errorf("parse: %w", err)
}
if c.Version != tokenCacheVersion {
// in the later iterations we could do state upgraders,
// so that we transform token cache from v1 to v2 without
// losing the tokens and asking the user to re-authenticate.
return fmt.Errorf("needs version %d, got version %d",
tokenCacheVersion, c.Version)
}
return nil
}
var _ TokenCache = (*FileTokenCache)(nil)

View File

@ -27,7 +27,7 @@ func setup(t *testing.T) string {
func TestStoreAndLookup(t *testing.T) {
setup(t)
c := &TokenCache{}
c := &FileTokenCache{}
err := c.Store("x", &oauth2.Token{
AccessToken: "abc",
})
@ -38,7 +38,7 @@ func TestStoreAndLookup(t *testing.T) {
})
require.NoError(t, err)
l := &TokenCache{}
l := &FileTokenCache{}
tok, err := l.Lookup("x")
require.NoError(t, err)
assert.Equal(t, "abc", tok.AccessToken)
@ -50,7 +50,7 @@ func TestStoreAndLookup(t *testing.T) {
func TestNoCacheFileReturnsErrNotConfigured(t *testing.T) {
setup(t)
l := &TokenCache{}
l := &FileTokenCache{}
_, err := l.Lookup("x")
assert.Equal(t, ErrNotConfigured, err)
}
@ -63,7 +63,7 @@ func TestLoadCorruptFile(t *testing.T) {
err = os.WriteFile(f, []byte("abc"), ownerExecReadWrite)
require.NoError(t, err)
l := &TokenCache{}
l := &FileTokenCache{}
_, err = l.Lookup("x")
assert.EqualError(t, err, "load: parse: invalid character 'a' looking for beginning of value")
}
@ -76,14 +76,14 @@ func TestLoadWrongVersion(t *testing.T) {
err = os.WriteFile(f, []byte(`{"version": 823, "things": []}`), ownerExecReadWrite)
require.NoError(t, err)
l := &TokenCache{}
l := &FileTokenCache{}
_, err = l.Lookup("x")
assert.EqualError(t, err, "load: needs version 1, got version 823")
}
func TestDevNull(t *testing.T) {
t.Setenv(homeEnvVar, "/dev/null")
l := &TokenCache{}
l := &FileTokenCache{}
_, err := l.Lookup("x")
// macOS/Linux: load: read: open /dev/null/.databricks/token-cache.json:
// windows: databricks OAuth is not configured for this host
@ -95,7 +95,7 @@ func TestStoreOnDev(t *testing.T) {
t.SkipNow()
}
t.Setenv(homeEnvVar, "/dev")
c := &TokenCache{}
c := &FileTokenCache{}
err := c.Store("x", &oauth2.Token{
AccessToken: "abc",
})

26
libs/auth/cache/in_memory.go vendored Normal file
View File

@ -0,0 +1,26 @@
package cache
import (
"golang.org/x/oauth2"
)
type InMemoryTokenCache struct {
Tokens map[string]*oauth2.Token
}
// Lookup implements TokenCache.
func (i *InMemoryTokenCache) Lookup(key string) (*oauth2.Token, error) {
token, ok := i.Tokens[key]
if !ok {
return nil, ErrNotConfigured
}
return token, nil
}
// Store implements TokenCache.
func (i *InMemoryTokenCache) Store(key string, t *oauth2.Token) error {
i.Tokens[key] = t
return nil
}
var _ TokenCache = (*InMemoryTokenCache)(nil)

44
libs/auth/cache/in_memory_test.go vendored Normal file
View File

@ -0,0 +1,44 @@
package cache
import (
"testing"
"github.com/stretchr/testify/assert"
"golang.org/x/oauth2"
)
func TestInMemoryCacheHit(t *testing.T) {
token := &oauth2.Token{
AccessToken: "abc",
}
c := &InMemoryTokenCache{
Tokens: map[string]*oauth2.Token{
"key": token,
},
}
res, err := c.Lookup("key")
assert.Equal(t, res, token)
assert.NoError(t, err)
}
func TestInMemoryCacheMiss(t *testing.T) {
c := &InMemoryTokenCache{
Tokens: map[string]*oauth2.Token{},
}
_, err := c.Lookup("key")
assert.ErrorIs(t, err, ErrNotConfigured)
}
func TestInMemoryCacheStore(t *testing.T) {
token := &oauth2.Token{
AccessToken: "abc",
}
c := &InMemoryTokenCache{
Tokens: map[string]*oauth2.Token{},
}
err := c.Store("key", token)
assert.NoError(t, err)
res, err := c.Lookup("key")
assert.Equal(t, res, token)
assert.NoError(t, err)
}

View File

@ -20,6 +20,20 @@ import (
"golang.org/x/oauth2/authhandler"
)
var apiClientForOauth int
func WithApiClientForOAuth(ctx context.Context, c *httpclient.ApiClient) context.Context {
return context.WithValue(ctx, &apiClientForOauth, c)
}
func GetApiClientForOAuth(ctx context.Context) *httpclient.ApiClient {
c, ok := ctx.Value(&apiClientForOauth).(*httpclient.ApiClient)
if !ok {
return httpclient.NewApiClient(httpclient.ClientConfig{})
}
return c
}
const (
// these values are predefined by Databricks as a public client
// and is specific to this application only. Using these values
@ -28,7 +42,7 @@ const (
appRedirectAddr = "localhost:8020"
// maximum amount of time to acquire listener on appRedirectAddr
DefaultTimeout = 45 * time.Second
listenerTimeout = 45 * time.Second
)
var ( // Databricks SDK API: `databricks OAuth is not` will be checked for presence
@ -42,14 +56,13 @@ type PersistentAuth struct {
AccountID string
http *httpclient.ApiClient
cache tokenCache
cache cache.TokenCache
ln net.Listener
browser func(string) error
}
type tokenCache interface {
Store(key string, t *oauth2.Token) error
Lookup(key string) (*oauth2.Token, error)
func (a *PersistentAuth) SetApiClient(h *httpclient.ApiClient) {
a.http = h
}
func (a *PersistentAuth) Load(ctx context.Context) (*oauth2.Token, error) {
@ -136,12 +149,10 @@ func (a *PersistentAuth) init(ctx context.Context) error {
return ErrFetchCredentials
}
if a.http == nil {
a.http = httpclient.NewApiClient(httpclient.ClientConfig{
// noop
})
a.http = GetApiClientForOAuth(ctx)
}
if a.cache == nil {
a.cache = &cache.TokenCache{}
a.cache = cache.GetTokenCache(ctx)
}
if a.browser == nil {
a.browser = browser.OpenURL
@ -149,7 +160,7 @@ func (a *PersistentAuth) init(ctx context.Context) error {
// try acquire listener, which we also use as a machine-local
// exclusive lock to prevent token cache corruption in the scope
// of developer machine, where this command runs.
listener, err := retries.Poll(ctx, DefaultTimeout,
listener, err := retries.Poll(ctx, listenerTimeout,
func() (*net.Listener, *retries.Err) {
var lc net.ListenConfig
l, err := lc.Listen(ctx, "tcp", appRedirectAddr)

View File

@ -68,7 +68,7 @@ func TestLoaderErrorsOnInvalidFile(t *testing.T) {
Loaders: []config.Loader{
ResolveProfileFromHost,
},
ConfigFile: "testdata/badcfg",
ConfigFile: "profile/testdata/badcfg",
Host: "https://default",
}
@ -81,7 +81,7 @@ func TestLoaderSkipsNoMatchingHost(t *testing.T) {
Loaders: []config.Loader{
ResolveProfileFromHost,
},
ConfigFile: "testdata/databrickscfg",
ConfigFile: "profile/testdata/databrickscfg",
Host: "https://noneofthehostsmatch",
}
@ -95,7 +95,7 @@ func TestLoaderMatchingHost(t *testing.T) {
Loaders: []config.Loader{
ResolveProfileFromHost,
},
ConfigFile: "testdata/databrickscfg",
ConfigFile: "profile/testdata/databrickscfg",
Host: "https://default",
}
@ -110,7 +110,7 @@ func TestLoaderMatchingHostWithQuery(t *testing.T) {
Loaders: []config.Loader{
ResolveProfileFromHost,
},
ConfigFile: "testdata/databrickscfg",
ConfigFile: "profile/testdata/databrickscfg",
Host: "https://query/?foo=bar",
}
@ -125,7 +125,7 @@ func TestLoaderErrorsOnMultipleMatches(t *testing.T) {
Loaders: []config.Loader{
ResolveProfileFromHost,
},
ConfigFile: "testdata/databrickscfg",
ConfigFile: "profile/testdata/databrickscfg",
Host: "https://foo/bar",
}

View File

@ -30,7 +30,7 @@ func TestLoadOrCreate_NotAllowed(t *testing.T) {
}
func TestLoadOrCreate_Bad(t *testing.T) {
path := "testdata/badcfg"
path := "profile/testdata/badcfg"
file, err := loadOrCreateConfigFile(path)
assert.Error(t, err)
assert.Nil(t, file)
@ -40,7 +40,7 @@ func TestMatchOrCreateSection_Direct(t *testing.T) {
cfg := &config.Config{
Profile: "query",
}
file, err := loadOrCreateConfigFile("testdata/databrickscfg")
file, err := loadOrCreateConfigFile("profile/testdata/databrickscfg")
assert.NoError(t, err)
ctx := context.Background()
@ -54,7 +54,7 @@ func TestMatchOrCreateSection_AccountID(t *testing.T) {
cfg := &config.Config{
AccountID: "abc",
}
file, err := loadOrCreateConfigFile("testdata/databrickscfg")
file, err := loadOrCreateConfigFile("profile/testdata/databrickscfg")
assert.NoError(t, err)
ctx := context.Background()
@ -68,7 +68,7 @@ func TestMatchOrCreateSection_NormalizeHost(t *testing.T) {
cfg := &config.Config{
Host: "https://query/?o=abracadabra",
}
file, err := loadOrCreateConfigFile("testdata/databrickscfg")
file, err := loadOrCreateConfigFile("profile/testdata/databrickscfg")
assert.NoError(t, err)
ctx := context.Background()
@ -80,7 +80,7 @@ func TestMatchOrCreateSection_NormalizeHost(t *testing.T) {
func TestMatchOrCreateSection_NoProfileOrHost(t *testing.T) {
cfg := &config.Config{}
file, err := loadOrCreateConfigFile("testdata/databrickscfg")
file, err := loadOrCreateConfigFile("profile/testdata/databrickscfg")
assert.NoError(t, err)
ctx := context.Background()
@ -92,7 +92,7 @@ func TestMatchOrCreateSection_MultipleProfiles(t *testing.T) {
cfg := &config.Config{
Host: "https://foo",
}
file, err := loadOrCreateConfigFile("testdata/databrickscfg")
file, err := loadOrCreateConfigFile("profile/testdata/databrickscfg")
assert.NoError(t, err)
ctx := context.Background()
@ -105,7 +105,7 @@ func TestMatchOrCreateSection_NewProfile(t *testing.T) {
Host: "https://bar",
Profile: "delirium",
}
file, err := loadOrCreateConfigFile("testdata/databrickscfg")
file, err := loadOrCreateConfigFile("profile/testdata/databrickscfg")
assert.NoError(t, err)
ctx := context.Background()

View File

@ -0,0 +1,17 @@
package profile
import "context"
var profiler int
func WithProfiler(ctx context.Context, p Profiler) context.Context {
return context.WithValue(ctx, &profiler, p)
}
func GetProfiler(ctx context.Context) Profiler {
p, ok := ctx.Value(&profiler).(Profiler)
if !ok {
return DefaultProfiler
}
return p
}

View File

@ -0,0 +1,100 @@
package profile
import (
"context"
"errors"
"fmt"
"io/fs"
"path/filepath"
"strings"
"github.com/databricks/cli/libs/env"
"github.com/databricks/databricks-sdk-go/config"
"github.com/spf13/cobra"
)
type FileProfilerImpl struct{}
func (f FileProfilerImpl) getPath(ctx context.Context, replaceHomeDirWithTilde bool) (string, error) {
configFile := env.Get(ctx, "DATABRICKS_CONFIG_FILE")
if configFile == "" {
configFile = "~/.databrickscfg"
}
if !replaceHomeDirWithTilde {
return configFile, nil
}
homedir, err := env.UserHomeDir(ctx)
if err != nil {
return "", err
}
configFile = strings.Replace(configFile, homedir, "~", 1)
return configFile, nil
}
// Get the path to the .databrickscfg file, falling back to the default in the current user's home directory.
func (f FileProfilerImpl) GetPath(ctx context.Context) (string, error) {
fp, err := f.getPath(ctx, true)
if err != nil {
return "", err
}
return filepath.Clean(fp), nil
}
var ErrNoConfiguration = errors.New("no configuration file found")
func (f FileProfilerImpl) Get(ctx context.Context) (*config.File, error) {
path, err := f.getPath(ctx, false)
if err != nil {
return nil, fmt.Errorf("cannot determine Databricks config file path: %w", err)
}
if strings.HasPrefix(path, "~") {
homedir, err := env.UserHomeDir(ctx)
if err != nil {
return nil, err
}
path = filepath.Join(homedir, path[1:])
}
configFile, err := config.LoadFile(path)
if errors.Is(err, fs.ErrNotExist) {
// downstreams depend on ErrNoConfiguration. TODO: expose this error through SDK
return nil, fmt.Errorf("%w at %s; please create one by running 'databricks configure'", ErrNoConfiguration, path)
} else if err != nil {
return nil, err
}
return configFile, nil
}
func (f FileProfilerImpl) LoadProfiles(ctx context.Context, fn ProfileMatchFunction) (profiles Profiles, err error) {
file, err := f.Get(ctx)
if err != nil {
return nil, fmt.Errorf("cannot load Databricks config file: %w", err)
}
// Iterate over sections and collect matching profiles.
for _, v := range file.Sections() {
all := v.KeysHash()
host, ok := all["host"]
if !ok {
// invalid profile
continue
}
profile := Profile{
Name: v.Name(),
Host: host,
AccountID: all["account_id"],
}
if fn(profile) {
profiles = append(profiles, profile)
}
}
return
}
func ProfileCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
profiles, err := DefaultProfiler.LoadProfiles(cmd.Context(), MatchAllProfiles)
if err != nil {
return nil, cobra.ShellCompDirectiveError
}
return profiles.Names(), cobra.ShellCompDirectiveNoFileComp
}

View File

@ -1,4 +1,4 @@
package databrickscfg
package profile
import (
"context"
@ -32,7 +32,8 @@ func TestLoadProfilesReturnsHomedirAsTilde(t *testing.T) {
ctx := context.Background()
ctx = env.WithUserHomeDir(ctx, "testdata")
ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "./testdata/databrickscfg")
file, _, err := LoadProfiles(ctx, func(p Profile) bool { return true })
profiler := FileProfilerImpl{}
file, err := profiler.GetPath(ctx)
require.NoError(t, err)
require.Equal(t, filepath.Clean("~/databrickscfg"), file)
}
@ -41,7 +42,8 @@ func TestLoadProfilesReturnsHomedirAsTildeExoticFile(t *testing.T) {
ctx := context.Background()
ctx = env.WithUserHomeDir(ctx, "testdata")
ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "~/databrickscfg")
file, _, err := LoadProfiles(ctx, func(p Profile) bool { return true })
profiler := FileProfilerImpl{}
file, err := profiler.GetPath(ctx)
require.NoError(t, err)
require.Equal(t, filepath.Clean("~/databrickscfg"), file)
}
@ -49,7 +51,8 @@ func TestLoadProfilesReturnsHomedirAsTildeExoticFile(t *testing.T) {
func TestLoadProfilesReturnsHomedirAsTildeDefaultFile(t *testing.T) {
ctx := context.Background()
ctx = env.WithUserHomeDir(ctx, "testdata/sample-home")
file, _, err := LoadProfiles(ctx, func(p Profile) bool { return true })
profiler := FileProfilerImpl{}
file, err := profiler.GetPath(ctx)
require.NoError(t, err)
require.Equal(t, filepath.Clean("~/.databrickscfg"), file)
}
@ -57,14 +60,16 @@ func TestLoadProfilesReturnsHomedirAsTildeDefaultFile(t *testing.T) {
func TestLoadProfilesNoConfiguration(t *testing.T) {
ctx := context.Background()
ctx = env.WithUserHomeDir(ctx, "testdata")
_, _, err := LoadProfiles(ctx, func(p Profile) bool { return true })
profiler := FileProfilerImpl{}
_, err := profiler.LoadProfiles(ctx, MatchAllProfiles)
require.ErrorIs(t, err, ErrNoConfiguration)
}
func TestLoadProfilesMatchWorkspace(t *testing.T) {
ctx := context.Background()
ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "./testdata/databrickscfg")
_, profiles, err := LoadProfiles(ctx, MatchWorkspaceProfiles)
profiler := FileProfilerImpl{}
profiles, err := profiler.LoadProfiles(ctx, MatchWorkspaceProfiles)
require.NoError(t, err)
assert.Equal(t, []string{"DEFAULT", "query", "foo1", "foo2"}, profiles.Names())
}
@ -72,7 +77,8 @@ func TestLoadProfilesMatchWorkspace(t *testing.T) {
func TestLoadProfilesMatchAccount(t *testing.T) {
ctx := context.Background()
ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "./testdata/databrickscfg")
_, profiles, err := LoadProfiles(ctx, MatchAccountProfiles)
profiler := FileProfilerImpl{}
profiles, err := profiler.LoadProfiles(ctx, MatchAccountProfiles)
require.NoError(t, err)
assert.Equal(t, []string{"acc"}, profiles.Names())
}

View File

@ -0,0 +1,25 @@
package profile
import "context"
type InMemoryProfiler struct {
Profiles Profiles
}
// GetPath implements Profiler.
func (i InMemoryProfiler) GetPath(context.Context) (string, error) {
return "<in memory>", nil
}
// LoadProfiles implements Profiler.
func (i InMemoryProfiler) LoadProfiles(ctx context.Context, f ProfileMatchFunction) (Profiles, error) {
res := make(Profiles, 0)
for _, p := range i.Profiles {
if f(p) {
res = append(res, p)
}
}
return res, nil
}
var _ Profiler = InMemoryProfiler{}

View File

@ -0,0 +1,49 @@
package profile
import (
"strings"
"github.com/databricks/databricks-sdk-go/config"
)
// Profile holds a subset of the keys in a databrickscfg profile.
// It should only be used for prompting and filtering.
// Use its name to construct a config.Config.
type Profile struct {
Name string
Host string
AccountID string
}
func (p Profile) Cloud() string {
cfg := config.Config{Host: p.Host}
switch {
case cfg.IsAws():
return "AWS"
case cfg.IsAzure():
return "Azure"
case cfg.IsGcp():
return "GCP"
default:
return ""
}
}
type Profiles []Profile
// SearchCaseInsensitive implements the promptui.Searcher interface.
// This allows the user to immediately starting typing to narrow down the list.
func (p Profiles) SearchCaseInsensitive(input string, index int) bool {
input = strings.ToLower(input)
name := strings.ToLower(p[index].Name)
host := strings.ToLower(p[index].Host)
return strings.Contains(name, input) || strings.Contains(host, input)
}
func (p Profiles) Names() []string {
names := make([]string, len(p))
for i, v := range p {
names[i] = v.Name
}
return names
}

View File

@ -0,0 +1,32 @@
package profile
import (
"context"
)
type ProfileMatchFunction func(Profile) bool
func MatchWorkspaceProfiles(p Profile) bool {
return p.AccountID == ""
}
func MatchAccountProfiles(p Profile) bool {
return p.Host != "" && p.AccountID != ""
}
func MatchAllProfiles(p Profile) bool {
return true
}
func WithName(name string) ProfileMatchFunction {
return func(p Profile) bool {
return p.Name == name
}
}
type Profiler interface {
LoadProfiles(context.Context, ProfileMatchFunction) (Profiles, error)
GetPath(context.Context) (string, error)
}
var DefaultProfiler = FileProfilerImpl{}

View File

@ -1,150 +0,0 @@
package databrickscfg
import (
"context"
"errors"
"fmt"
"io/fs"
"path/filepath"
"strings"
"github.com/databricks/cli/libs/env"
"github.com/databricks/databricks-sdk-go/config"
"github.com/spf13/cobra"
)
// Profile holds a subset of the keys in a databrickscfg profile.
// It should only be used for prompting and filtering.
// Use its name to construct a config.Config.
type Profile struct {
Name string
Host string
AccountID string
}
func (p Profile) Cloud() string {
cfg := config.Config{Host: p.Host}
switch {
case cfg.IsAws():
return "AWS"
case cfg.IsAzure():
return "Azure"
case cfg.IsGcp():
return "GCP"
default:
return ""
}
}
type Profiles []Profile
func (p Profiles) Names() []string {
names := make([]string, len(p))
for i, v := range p {
names[i] = v.Name
}
return names
}
// SearchCaseInsensitive implements the promptui.Searcher interface.
// This allows the user to immediately starting typing to narrow down the list.
func (p Profiles) SearchCaseInsensitive(input string, index int) bool {
input = strings.ToLower(input)
name := strings.ToLower(p[index].Name)
host := strings.ToLower(p[index].Host)
return strings.Contains(name, input) || strings.Contains(host, input)
}
type ProfileMatchFunction func(Profile) bool
func MatchWorkspaceProfiles(p Profile) bool {
return p.AccountID == ""
}
func MatchAccountProfiles(p Profile) bool {
return p.Host != "" && p.AccountID != ""
}
func MatchAllProfiles(p Profile) bool {
return true
}
// Get the path to the .databrickscfg file, falling back to the default in the current user's home directory.
func GetPath(ctx context.Context) (string, error) {
configFile := env.Get(ctx, "DATABRICKS_CONFIG_FILE")
if configFile == "" {
configFile = "~/.databrickscfg"
}
if strings.HasPrefix(configFile, "~") {
homedir, err := env.UserHomeDir(ctx)
if err != nil {
return "", err
}
configFile = filepath.Join(homedir, configFile[1:])
}
return configFile, nil
}
var ErrNoConfiguration = errors.New("no configuration file found")
func Get(ctx context.Context) (*config.File, error) {
path, err := GetPath(ctx)
if err != nil {
return nil, fmt.Errorf("cannot determine Databricks config file path: %w", err)
}
configFile, err := config.LoadFile(path)
if errors.Is(err, fs.ErrNotExist) {
// downstreams depend on ErrNoConfiguration. TODO: expose this error through SDK
return nil, fmt.Errorf("%w at %s; please create one by running 'databricks configure'", ErrNoConfiguration, path)
} else if err != nil {
return nil, err
}
return configFile, nil
}
func LoadProfiles(ctx context.Context, fn ProfileMatchFunction) (file string, profiles Profiles, err error) {
f, err := Get(ctx)
if err != nil {
return "", nil, fmt.Errorf("cannot load Databricks config file: %w", err)
}
// Replace homedir with ~ if applicable.
// This is to make the output more readable.
file = filepath.Clean(f.Path())
home, err := env.UserHomeDir(ctx)
if err != nil {
return "", nil, err
}
homedir := filepath.Clean(home)
if strings.HasPrefix(file, homedir) {
file = "~" + file[len(homedir):]
}
// Iterate over sections and collect matching profiles.
for _, v := range f.Sections() {
all := v.KeysHash()
host, ok := all["host"]
if !ok {
// invalid profile
continue
}
profile := Profile{
Name: v.Name(),
Host: host,
AccountID: all["account_id"],
}
if fn(profile) {
profiles = append(profiles, profile)
}
}
return
}
func ProfileCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
_, profiles, err := LoadProfiles(cmd.Context(), MatchAllProfiles)
if err != nil {
return nil, cobra.ShellCompDirectiveError
}
return profiles.Names(), cobra.ShellCompDirectiveNoFileComp
}