mirror of https://github.com/databricks/cli.git
Improve token refresh flow (#1434)
## Changes Currently, there are a number of issues with the non-happy-path flows for token refresh in the CLI. If the token refresh fails, the raw error message is presented to the user, as seen below. This message is very difficult for users to interpret and doesn't give any clear direction on how to resolve this issue. ``` Error: token refresh: Post "https://adb-<WSID>.azuredatabricks.net/oidc/v1/token": http 400: {"error":"invalid_request","error_description":"Refresh token is invalid"} ``` When logging in again, I've noticed that the timeout for logging in is very short, only 45 seconds. If a user is using a password manager and needs to login to that first, or needs to do MFA, 45 seconds may not be enough time. to an account-level profile, it is quite frustrating for users to need to re-enter account ID information when that information is already stored in the user's `.databrickscfg` file. This PR tackles these two issues. First, the presentation of error messages from `databricks auth token` is improved substantially by converting the `error` into a human-readable message. When the refresh token is invalid, it will present a command for the user to run to reauthenticate. If the token fetching failed for some other reason, that reason will be presented in a nice way, providing front-line debugging steps and ultimately redirecting users to file a ticket at this repo if they can't resolve the issue themselves. After this PR, the new error message is: ``` Error: a new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run `.databricks/databricks auth login --host https://adb-<WSID>.azuredatabricks.net` ``` To improve the login flow, this PR modifies `databricks auth login` to auto-complete the account ID from the profile when present. Additionally, it increases the login timeout from 45 seconds to 1 hour to give the user sufficient time to login as needed. To test this change, I needed to refactor some components of the CLI around profile management, the token cache, and the API client used to fetch OAuth tokens. These are now settable in the context, and a demonstration of how they can be set and used is found in `auth_test.go`. Separately, this also demonstrates a sort-of integration test of the CLI by executing the Cobra command for `databricks auth token` from tests, which may be useful for testing other end-to-end functionality in the CLI. In particular, I believe this is necessary in order to set flag values (like the `--profile` flag in this case) for use in testing. ## Tests Unit tests cover the unhappy and happy paths using the mocked API client, token cache, and profiler. Manually tested --------- Co-authored-by: Pieter Noordhuis <pieter.noordhuis@databricks.com>
This commit is contained in:
parent
157877a152
commit
f7d4b272f4
|
@ -1,6 +1,8 @@
|
||||||
package config_tests
|
package config_tests
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
@ -9,12 +11,14 @@ import (
|
||||||
func TestGitAutoLoadWithEnvironment(t *testing.T) {
|
func TestGitAutoLoadWithEnvironment(t *testing.T) {
|
||||||
b := load(t, "./environments_autoload_git")
|
b := load(t, "./environments_autoload_git")
|
||||||
assert.True(t, b.Config.Bundle.Git.Inferred)
|
assert.True(t, b.Config.Bundle.Git.Inferred)
|
||||||
assert.Contains(t, b.Config.Bundle.Git.OriginURL, "/cli")
|
validUrl := strings.Contains(b.Config.Bundle.Git.OriginURL, "/cli") || strings.Contains(b.Config.Bundle.Git.OriginURL, "/bricks")
|
||||||
|
assert.True(t, validUrl, fmt.Sprintf("Expected URL to contain '/cli' or '/bricks', got %s", b.Config.Bundle.Git.OriginURL))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGitManuallySetBranchWithEnvironment(t *testing.T) {
|
func TestGitManuallySetBranchWithEnvironment(t *testing.T) {
|
||||||
b := loadTarget(t, "./environments_autoload_git", "production")
|
b := loadTarget(t, "./environments_autoload_git", "production")
|
||||||
assert.False(t, b.Config.Bundle.Git.Inferred)
|
assert.False(t, b.Config.Bundle.Git.Inferred)
|
||||||
assert.Equal(t, "main", b.Config.Bundle.Git.Branch)
|
assert.Equal(t, "main", b.Config.Bundle.Git.Branch)
|
||||||
assert.Contains(t, b.Config.Bundle.Git.OriginURL, "/cli")
|
validUrl := strings.Contains(b.Config.Bundle.Git.OriginURL, "/cli") || strings.Contains(b.Config.Bundle.Git.OriginURL, "/bricks")
|
||||||
|
assert.True(t, validUrl, fmt.Sprintf("Expected URL to contain '/cli' or '/bricks', got %s", b.Config.Bundle.Git.OriginURL))
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,8 @@ package config_tests
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/databricks/cli/bundle"
|
"github.com/databricks/cli/bundle"
|
||||||
|
@ -13,14 +15,16 @@ import (
|
||||||
func TestGitAutoLoad(t *testing.T) {
|
func TestGitAutoLoad(t *testing.T) {
|
||||||
b := load(t, "./autoload_git")
|
b := load(t, "./autoload_git")
|
||||||
assert.True(t, b.Config.Bundle.Git.Inferred)
|
assert.True(t, b.Config.Bundle.Git.Inferred)
|
||||||
assert.Contains(t, b.Config.Bundle.Git.OriginURL, "/cli")
|
validUrl := strings.Contains(b.Config.Bundle.Git.OriginURL, "/cli") || strings.Contains(b.Config.Bundle.Git.OriginURL, "/bricks")
|
||||||
|
assert.True(t, validUrl, fmt.Sprintf("Expected URL to contain '/cli' or '/bricks', got %s", b.Config.Bundle.Git.OriginURL))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGitManuallySetBranch(t *testing.T) {
|
func TestGitManuallySetBranch(t *testing.T) {
|
||||||
b := loadTarget(t, "./autoload_git", "production")
|
b := loadTarget(t, "./autoload_git", "production")
|
||||||
assert.False(t, b.Config.Bundle.Git.Inferred)
|
assert.False(t, b.Config.Bundle.Git.Inferred)
|
||||||
assert.Equal(t, "main", b.Config.Bundle.Git.Branch)
|
assert.Equal(t, "main", b.Config.Bundle.Git.Branch)
|
||||||
assert.Contains(t, b.Config.Bundle.Git.OriginURL, "/cli")
|
validUrl := strings.Contains(b.Config.Bundle.Git.OriginURL, "/cli") || strings.Contains(b.Config.Bundle.Git.OriginURL, "/bricks")
|
||||||
|
assert.True(t, validUrl, fmt.Sprintf("Expected URL to contain '/cli' or '/bricks', got %s", b.Config.Bundle.Git.OriginURL))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGitBundleBranchValidation(t *testing.T) {
|
func TestGitBundleBranchValidation(t *testing.T) {
|
||||||
|
|
|
@ -10,7 +10,7 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/databricks/cli/libs/databrickscfg"
|
"github.com/databricks/cli/libs/databrickscfg/profile"
|
||||||
"github.com/databricks/databricks-sdk-go/config"
|
"github.com/databricks/databricks-sdk-go/config"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"gopkg.in/ini.v1"
|
"gopkg.in/ini.v1"
|
||||||
|
@ -70,7 +70,7 @@ func resolveSection(cfg *config.Config, iniFile *config.File) (*ini.Section, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadFromDatabricksCfg(ctx context.Context, cfg *config.Config) error {
|
func loadFromDatabricksCfg(ctx context.Context, cfg *config.Config) error {
|
||||||
iniFile, err := databrickscfg.Get(ctx)
|
iniFile, err := profile.DefaultProfiler.Get(ctx)
|
||||||
if errors.Is(err, fs.ErrNotExist) {
|
if errors.Is(err, fs.ErrNotExist) {
|
||||||
// it's fine not to have ~/.databrickscfg
|
// it's fine not to have ~/.databrickscfg
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"github.com/databricks/cli/libs/cmdio"
|
"github.com/databricks/cli/libs/cmdio"
|
||||||
"github.com/databricks/cli/libs/databrickscfg"
|
"github.com/databricks/cli/libs/databrickscfg"
|
||||||
"github.com/databricks/cli/libs/databrickscfg/cfgpickers"
|
"github.com/databricks/cli/libs/databrickscfg/cfgpickers"
|
||||||
|
"github.com/databricks/cli/libs/databrickscfg/profile"
|
||||||
"github.com/databricks/databricks-sdk-go"
|
"github.com/databricks/databricks-sdk-go"
|
||||||
"github.com/databricks/databricks-sdk-go/config"
|
"github.com/databricks/databricks-sdk-go/config"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
@ -31,6 +32,7 @@ func configureHost(ctx context.Context, persistentAuth *auth.PersistentAuth, arg
|
||||||
}
|
}
|
||||||
|
|
||||||
const minimalDbConnectVersion = "13.1"
|
const minimalDbConnectVersion = "13.1"
|
||||||
|
const defaultTimeout = 1 * time.Hour
|
||||||
|
|
||||||
func newLoginCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
|
func newLoginCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
|
||||||
defaultConfigPath := "~/.databrickscfg"
|
defaultConfigPath := "~/.databrickscfg"
|
||||||
|
@ -84,7 +86,7 @@ depends on the existing profiles you have set in your configuration file
|
||||||
|
|
||||||
var loginTimeout time.Duration
|
var loginTimeout time.Duration
|
||||||
var configureCluster bool
|
var configureCluster bool
|
||||||
cmd.Flags().DurationVar(&loginTimeout, "timeout", auth.DefaultTimeout,
|
cmd.Flags().DurationVar(&loginTimeout, "timeout", defaultTimeout,
|
||||||
"Timeout for completing login challenge in the browser")
|
"Timeout for completing login challenge in the browser")
|
||||||
cmd.Flags().BoolVar(&configureCluster, "configure-cluster", false,
|
cmd.Flags().BoolVar(&configureCluster, "configure-cluster", false,
|
||||||
"Prompts to configure cluster")
|
"Prompts to configure cluster")
|
||||||
|
@ -108,7 +110,7 @@ depends on the existing profiles you have set in your configuration file
|
||||||
profileName = profile
|
profileName = profile
|
||||||
}
|
}
|
||||||
|
|
||||||
err := setHost(ctx, profileName, persistentAuth, args)
|
err := setHostAndAccountId(ctx, profileName, persistentAuth, args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -118,16 +120,9 @@ depends on the existing profiles you have set in your configuration file
|
||||||
// Otherwise it will complain about non existing profile because it was not yet saved.
|
// Otherwise it will complain about non existing profile because it was not yet saved.
|
||||||
cfg := config.Config{
|
cfg := config.Config{
|
||||||
Host: persistentAuth.Host,
|
Host: persistentAuth.Host,
|
||||||
|
AccountID: persistentAuth.AccountID,
|
||||||
AuthType: "databricks-cli",
|
AuthType: "databricks-cli",
|
||||||
}
|
}
|
||||||
if cfg.IsAccountClient() && persistentAuth.AccountID == "" {
|
|
||||||
accountId, err := promptForAccountID(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
persistentAuth.AccountID = accountId
|
|
||||||
}
|
|
||||||
cfg.AccountID = persistentAuth.AccountID
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, loginTimeout)
|
ctx, cancel := context.WithTimeout(ctx, loginTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
@ -172,15 +167,15 @@ depends on the existing profiles you have set in your configuration file
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
|
||||||
func setHost(ctx context.Context, profileName string, persistentAuth *auth.PersistentAuth, args []string) error {
|
func setHostAndAccountId(ctx context.Context, profileName string, persistentAuth *auth.PersistentAuth, args []string) error {
|
||||||
|
profiler := profile.GetProfiler(ctx)
|
||||||
// If the chosen profile has a hostname and the user hasn't specified a host, infer the host from the profile.
|
// If the chosen profile has a hostname and the user hasn't specified a host, infer the host from the profile.
|
||||||
_, profiles, err := databrickscfg.LoadProfiles(ctx, func(p databrickscfg.Profile) bool {
|
profiles, err := profiler.LoadProfiles(ctx, profile.WithName(profileName))
|
||||||
return p.Name == profileName
|
|
||||||
})
|
|
||||||
// Tolerate ErrNoConfiguration here, as we will write out a configuration as part of the login flow.
|
// Tolerate ErrNoConfiguration here, as we will write out a configuration as part of the login flow.
|
||||||
if err != nil && !errors.Is(err, databrickscfg.ErrNoConfiguration) {
|
if err != nil && !errors.Is(err, profile.ErrNoConfiguration) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if persistentAuth.Host == "" {
|
if persistentAuth.Host == "" {
|
||||||
if len(profiles) > 0 && profiles[0].Host != "" {
|
if len(profiles) > 0 && profiles[0].Host != "" {
|
||||||
persistentAuth.Host = profiles[0].Host
|
persistentAuth.Host = profiles[0].Host
|
||||||
|
@ -188,5 +183,17 @@ func setHost(ctx context.Context, profileName string, persistentAuth *auth.Persi
|
||||||
configureHost(ctx, persistentAuth, args, 0)
|
configureHost(ctx, persistentAuth, args, 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
isAccountClient := (&config.Config{Host: persistentAuth.Host}).IsAccountClient()
|
||||||
|
if isAccountClient && persistentAuth.AccountID == "" {
|
||||||
|
if len(profiles) > 0 && profiles[0].AccountID != "" {
|
||||||
|
persistentAuth.AccountID = profiles[0].AccountID
|
||||||
|
} else {
|
||||||
|
accountId, err := promptForAccountID(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
persistentAuth.AccountID = accountId
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,6 @@ import (
|
||||||
func TestSetHostDoesNotFailWithNoDatabrickscfg(t *testing.T) {
|
func TestSetHostDoesNotFailWithNoDatabrickscfg(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "./imaginary-file/databrickscfg")
|
ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "./imaginary-file/databrickscfg")
|
||||||
err := setHost(ctx, "foo", &auth.PersistentAuth{Host: "test"}, []string{})
|
err := setHostAndAccountId(ctx, "foo", &auth.PersistentAuth{Host: "test"}, []string{})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,7 +8,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/databricks/cli/libs/cmdio"
|
"github.com/databricks/cli/libs/cmdio"
|
||||||
"github.com/databricks/cli/libs/databrickscfg"
|
"github.com/databricks/cli/libs/databrickscfg/profile"
|
||||||
"github.com/databricks/cli/libs/log"
|
"github.com/databricks/cli/libs/log"
|
||||||
"github.com/databricks/databricks-sdk-go"
|
"github.com/databricks/databricks-sdk-go"
|
||||||
"github.com/databricks/databricks-sdk-go/config"
|
"github.com/databricks/databricks-sdk-go/config"
|
||||||
|
@ -94,7 +94,7 @@ func newProfilesCommand() *cobra.Command {
|
||||||
|
|
||||||
cmd.RunE = func(cmd *cobra.Command, args []string) error {
|
cmd.RunE = func(cmd *cobra.Command, args []string) error {
|
||||||
var profiles []*profileMetadata
|
var profiles []*profileMetadata
|
||||||
iniFile, err := databrickscfg.Get(cmd.Context())
|
iniFile, err := profile.DefaultProfiler.Get(cmd.Context())
|
||||||
if os.IsNotExist(err) {
|
if os.IsNotExist(err) {
|
||||||
// return empty list for non-configured machines
|
// return empty list for non-configured machines
|
||||||
iniFile = &config.File{
|
iniFile = &config.File{
|
||||||
|
|
|
@ -4,12 +4,44 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/databricks/cli/libs/auth"
|
"github.com/databricks/cli/libs/auth"
|
||||||
|
"github.com/databricks/databricks-sdk-go/httpclient"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type tokenErrorResponse struct {
|
||||||
|
Error string `json:"error"`
|
||||||
|
ErrorDescription string `json:"error_description"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildLoginCommand(profile string, persistentAuth *auth.PersistentAuth) string {
|
||||||
|
executable := os.Args[0]
|
||||||
|
cmd := []string{
|
||||||
|
executable,
|
||||||
|
"auth",
|
||||||
|
"login",
|
||||||
|
}
|
||||||
|
if profile != "" {
|
||||||
|
cmd = append(cmd, "--profile", profile)
|
||||||
|
} else {
|
||||||
|
cmd = append(cmd, "--host", persistentAuth.Host)
|
||||||
|
if persistentAuth.AccountID != "" {
|
||||||
|
cmd = append(cmd, "--account-id", persistentAuth.AccountID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(cmd, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func helpfulError(profile string, persistentAuth *auth.PersistentAuth) string {
|
||||||
|
loginMsg := buildLoginCommand(profile, persistentAuth)
|
||||||
|
return fmt.Sprintf("Try logging in again with `%s` before retrying. If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new", loginMsg)
|
||||||
|
}
|
||||||
|
|
||||||
func newTokenCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
|
func newTokenCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
|
||||||
cmd := &cobra.Command{
|
cmd := &cobra.Command{
|
||||||
Use: "token [HOST]",
|
Use: "token [HOST]",
|
||||||
|
@ -17,7 +49,7 @@ func newTokenCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
|
||||||
}
|
}
|
||||||
|
|
||||||
var tokenTimeout time.Duration
|
var tokenTimeout time.Duration
|
||||||
cmd.Flags().DurationVar(&tokenTimeout, "timeout", auth.DefaultTimeout,
|
cmd.Flags().DurationVar(&tokenTimeout, "timeout", defaultTimeout,
|
||||||
"Timeout for acquiring a token.")
|
"Timeout for acquiring a token.")
|
||||||
|
|
||||||
cmd.RunE = func(cmd *cobra.Command, args []string) error {
|
cmd.RunE = func(cmd *cobra.Command, args []string) error {
|
||||||
|
@ -29,11 +61,11 @@ func newTokenCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
|
||||||
profileName = profileFlag.Value.String()
|
profileName = profileFlag.Value.String()
|
||||||
// If a profile is provided we read the host from the .databrickscfg file
|
// If a profile is provided we read the host from the .databrickscfg file
|
||||||
if profileName != "" && len(args) > 0 {
|
if profileName != "" && len(args) > 0 {
|
||||||
return errors.New("providing both a profile and a host parameters is not supported")
|
return errors.New("providing both a profile and host is not supported")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err := setHost(ctx, profileName, persistentAuth, args)
|
err := setHostAndAccountId(ctx, profileName, persistentAuth, args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -42,8 +74,21 @@ func newTokenCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
|
||||||
ctx, cancel := context.WithTimeout(ctx, tokenTimeout)
|
ctx, cancel := context.WithTimeout(ctx, tokenTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
t, err := persistentAuth.Load(ctx)
|
t, err := persistentAuth.Load(ctx)
|
||||||
|
var httpErr *httpclient.HttpError
|
||||||
|
if errors.As(err, &httpErr) {
|
||||||
|
helpMsg := helpfulError(profileName, persistentAuth)
|
||||||
|
t := &tokenErrorResponse{}
|
||||||
|
err = json.Unmarshal([]byte(httpErr.Message), t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("unexpected parsing token response: %w. %s", err, helpMsg)
|
||||||
|
}
|
||||||
|
if t.ErrorDescription == "Refresh token is invalid" {
|
||||||
|
return fmt.Errorf("a new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run `%s`", buildLoginCommand(profileName, persistentAuth))
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("unexpected error refreshing token: %s. %s", t.ErrorDescription, helpMsg)
|
||||||
|
}
|
||||||
|
} else if err != nil {
|
||||||
|
return fmt.Errorf("unexpected error refreshing token: %w. %s", err, helpfulError(profileName, persistentAuth))
|
||||||
}
|
}
|
||||||
raw, err := json.MarshalIndent(t, "", " ")
|
raw, err := json.MarshalIndent(t, "", " ")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -0,0 +1,168 @@
|
||||||
|
package auth_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/databricks/cli/cmd"
|
||||||
|
"github.com/databricks/cli/libs/auth"
|
||||||
|
"github.com/databricks/cli/libs/auth/cache"
|
||||||
|
"github.com/databricks/cli/libs/databrickscfg/profile"
|
||||||
|
"github.com/databricks/databricks-sdk-go/httpclient"
|
||||||
|
"github.com/databricks/databricks-sdk-go/httpclient/fixtures"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
var refreshFailureTokenResponse = fixtures.HTTPFixture{
|
||||||
|
MatchAny: true,
|
||||||
|
Status: 401,
|
||||||
|
Response: map[string]string{
|
||||||
|
"error": "invalid_request",
|
||||||
|
"error_description": "Refresh token is invalid",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var refreshFailureInvalidResponse = fixtures.HTTPFixture{
|
||||||
|
MatchAny: true,
|
||||||
|
Status: 401,
|
||||||
|
Response: "Not json",
|
||||||
|
}
|
||||||
|
|
||||||
|
var refreshFailureOtherError = fixtures.HTTPFixture{
|
||||||
|
MatchAny: true,
|
||||||
|
Status: 401,
|
||||||
|
Response: map[string]string{
|
||||||
|
"error": "other_error",
|
||||||
|
"error_description": "Databricks is down",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var refreshSuccessTokenResponse = fixtures.HTTPFixture{
|
||||||
|
MatchAny: true,
|
||||||
|
Status: 200,
|
||||||
|
Response: map[string]string{
|
||||||
|
"access_token": "new-access-token",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"expires_in": "3600",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateToken(t *testing.T, resp string) {
|
||||||
|
res := map[string]string{}
|
||||||
|
err := json.Unmarshal([]byte(resp), &res)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "new-access-token", res["access_token"])
|
||||||
|
assert.Equal(t, "Bearer", res["token_type"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func getContextForTest(f fixtures.HTTPFixture) context.Context {
|
||||||
|
profiler := profile.InMemoryProfiler{
|
||||||
|
Profiles: profile.Profiles{
|
||||||
|
{
|
||||||
|
Name: "expired",
|
||||||
|
Host: "https://accounts.cloud.databricks.com",
|
||||||
|
AccountID: "expired",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "active",
|
||||||
|
Host: "https://accounts.cloud.databricks.com",
|
||||||
|
AccountID: "active",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tokenCache := &cache.InMemoryTokenCache{
|
||||||
|
Tokens: map[string]*oauth2.Token{
|
||||||
|
"https://accounts.cloud.databricks.com/oidc/accounts/expired": {
|
||||||
|
RefreshToken: "expired",
|
||||||
|
},
|
||||||
|
"https://accounts.cloud.databricks.com/oidc/accounts/active": {
|
||||||
|
RefreshToken: "active",
|
||||||
|
Expiry: time.Now().Add(1 * time.Hour), // Hopefully unit tests don't take an hour to run
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
client := httpclient.NewApiClient(httpclient.ClientConfig{
|
||||||
|
Transport: fixtures.SliceTransport{f},
|
||||||
|
})
|
||||||
|
ctx := profile.WithProfiler(context.Background(), profiler)
|
||||||
|
ctx = cache.WithTokenCache(ctx, tokenCache)
|
||||||
|
ctx = auth.WithApiClientForOAuth(ctx, client)
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCobraCmdForTest(f fixtures.HTTPFixture) (*cobra.Command, *bytes.Buffer) {
|
||||||
|
ctx := getContextForTest(f)
|
||||||
|
c := cmd.New(ctx)
|
||||||
|
output := &bytes.Buffer{}
|
||||||
|
c.SetOut(output)
|
||||||
|
return c, output
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenCmdWithProfilePrintsHelpfulLoginMessageOnRefreshFailure(t *testing.T) {
|
||||||
|
cmd, output := getCobraCmdForTest(refreshFailureTokenResponse)
|
||||||
|
cmd.SetArgs([]string{"auth", "token", "--profile", "expired"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
|
||||||
|
out := output.String()
|
||||||
|
assert.Empty(t, out)
|
||||||
|
assert.ErrorContains(t, err, "a new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run ")
|
||||||
|
assert.ErrorContains(t, err, "auth login --profile expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenCmdWithHostPrintsHelpfulLoginMessageOnRefreshFailure(t *testing.T) {
|
||||||
|
cmd, output := getCobraCmdForTest(refreshFailureTokenResponse)
|
||||||
|
cmd.SetArgs([]string{"auth", "token", "--host", "https://accounts.cloud.databricks.com", "--account-id", "expired"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
|
||||||
|
out := output.String()
|
||||||
|
assert.Empty(t, out)
|
||||||
|
assert.ErrorContains(t, err, "a new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run ")
|
||||||
|
assert.ErrorContains(t, err, "auth login --host https://accounts.cloud.databricks.com --account-id expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenCmdInvalidResponse(t *testing.T) {
|
||||||
|
cmd, output := getCobraCmdForTest(refreshFailureInvalidResponse)
|
||||||
|
cmd.SetArgs([]string{"auth", "token", "--profile", "active"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
|
||||||
|
out := output.String()
|
||||||
|
assert.Empty(t, out)
|
||||||
|
assert.ErrorContains(t, err, "unexpected parsing token response: invalid character 'N' looking for beginning of value. Try logging in again with ")
|
||||||
|
assert.ErrorContains(t, err, "auth login --profile active` before retrying. If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenCmdOtherErrorResponse(t *testing.T) {
|
||||||
|
cmd, output := getCobraCmdForTest(refreshFailureOtherError)
|
||||||
|
cmd.SetArgs([]string{"auth", "token", "--profile", "active"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
|
||||||
|
out := output.String()
|
||||||
|
assert.Empty(t, out)
|
||||||
|
assert.ErrorContains(t, err, "unexpected error refreshing token: Databricks is down. Try logging in again with ")
|
||||||
|
assert.ErrorContains(t, err, "auth login --profile active` before retrying. If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenCmdWithProfileSuccess(t *testing.T) {
|
||||||
|
cmd, output := getCobraCmdForTest(refreshSuccessTokenResponse)
|
||||||
|
cmd.SetArgs([]string{"auth", "token", "--profile", "active"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
|
||||||
|
out := output.String()
|
||||||
|
validateToken(t, out)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenCmdWithHostSuccess(t *testing.T) {
|
||||||
|
cmd, output := getCobraCmdForTest(refreshSuccessTokenResponse)
|
||||||
|
cmd.SetArgs([]string{"auth", "token", "--host", "https://accounts.cloud.databricks.com", "--account-id", "expired"})
|
||||||
|
err := cmd.Execute()
|
||||||
|
|
||||||
|
out := output.String()
|
||||||
|
validateToken(t, out)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
|
@ -11,8 +11,8 @@ import (
|
||||||
"github.com/databricks/cli/cmd/labs/github"
|
"github.com/databricks/cli/cmd/labs/github"
|
||||||
"github.com/databricks/cli/cmd/labs/unpack"
|
"github.com/databricks/cli/cmd/labs/unpack"
|
||||||
"github.com/databricks/cli/libs/cmdio"
|
"github.com/databricks/cli/libs/cmdio"
|
||||||
"github.com/databricks/cli/libs/databrickscfg"
|
|
||||||
"github.com/databricks/cli/libs/databrickscfg/cfgpickers"
|
"github.com/databricks/cli/libs/databrickscfg/cfgpickers"
|
||||||
|
"github.com/databricks/cli/libs/databrickscfg/profile"
|
||||||
"github.com/databricks/cli/libs/log"
|
"github.com/databricks/cli/libs/log"
|
||||||
"github.com/databricks/cli/libs/process"
|
"github.com/databricks/cli/libs/process"
|
||||||
"github.com/databricks/cli/libs/python"
|
"github.com/databricks/cli/libs/python"
|
||||||
|
@ -89,7 +89,7 @@ func (i *installer) Install(ctx context.Context) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
w, err := i.login(ctx)
|
w, err := i.login(ctx)
|
||||||
if err != nil && errors.Is(err, databrickscfg.ErrNoConfiguration) {
|
if err != nil && errors.Is(err, profile.ErrNoConfiguration) {
|
||||||
cfg, err := i.Installer.envAwareConfig(ctx)
|
cfg, err := i.Installer.envAwareConfig(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/databricks/cli/libs/cmdio"
|
"github.com/databricks/cli/libs/cmdio"
|
||||||
"github.com/databricks/cli/libs/databrickscfg"
|
"github.com/databricks/cli/libs/databrickscfg/profile"
|
||||||
"github.com/databricks/databricks-sdk-go"
|
"github.com/databricks/databricks-sdk-go"
|
||||||
"github.com/databricks/databricks-sdk-go/config"
|
"github.com/databricks/databricks-sdk-go/config"
|
||||||
"github.com/manifoldco/promptui"
|
"github.com/manifoldco/promptui"
|
||||||
|
@ -37,7 +37,7 @@ func (e ErrNoAccountProfiles) Error() string {
|
||||||
|
|
||||||
func initProfileFlag(cmd *cobra.Command) {
|
func initProfileFlag(cmd *cobra.Command) {
|
||||||
cmd.PersistentFlags().StringP("profile", "p", "", "~/.databrickscfg profile")
|
cmd.PersistentFlags().StringP("profile", "p", "", "~/.databrickscfg profile")
|
||||||
cmd.RegisterFlagCompletionFunc("profile", databrickscfg.ProfileCompletion)
|
cmd.RegisterFlagCompletionFunc("profile", profile.ProfileCompletion)
|
||||||
}
|
}
|
||||||
|
|
||||||
func profileFlagValue(cmd *cobra.Command) (string, bool) {
|
func profileFlagValue(cmd *cobra.Command) (string, bool) {
|
||||||
|
@ -111,27 +111,29 @@ func MustAccountClient(cmd *cobra.Command, args []string) error {
|
||||||
cfg := &config.Config{}
|
cfg := &config.Config{}
|
||||||
|
|
||||||
// The command-line profile flag takes precedence over DATABRICKS_CONFIG_PROFILE.
|
// The command-line profile flag takes precedence over DATABRICKS_CONFIG_PROFILE.
|
||||||
profile, hasProfileFlag := profileFlagValue(cmd)
|
pr, hasProfileFlag := profileFlagValue(cmd)
|
||||||
if hasProfileFlag {
|
if hasProfileFlag {
|
||||||
cfg.Profile = profile
|
cfg.Profile = pr
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := cmd.Context()
|
ctx := cmd.Context()
|
||||||
ctx = context.WithValue(ctx, &configUsed, cfg)
|
ctx = context.WithValue(ctx, &configUsed, cfg)
|
||||||
cmd.SetContext(ctx)
|
cmd.SetContext(ctx)
|
||||||
|
|
||||||
|
profiler := profile.GetProfiler(ctx)
|
||||||
|
|
||||||
if cfg.Profile == "" {
|
if cfg.Profile == "" {
|
||||||
// account-level CLI was not really done before, so here are the assumptions:
|
// account-level CLI was not really done before, so here are the assumptions:
|
||||||
// 1. only admins will have account configured
|
// 1. only admins will have account configured
|
||||||
// 2. 99% of admins will have access to just one account
|
// 2. 99% of admins will have access to just one account
|
||||||
// hence, we don't need to create a special "DEFAULT_ACCOUNT" profile yet
|
// hence, we don't need to create a special "DEFAULT_ACCOUNT" profile yet
|
||||||
_, profiles, err := databrickscfg.LoadProfiles(cmd.Context(), databrickscfg.MatchAccountProfiles)
|
profiles, err := profiler.LoadProfiles(cmd.Context(), profile.MatchAccountProfiles)
|
||||||
if err == nil && len(profiles) == 1 {
|
if err == nil && len(profiles) == 1 {
|
||||||
cfg.Profile = profiles[0].Name
|
cfg.Profile = profiles[0].Name
|
||||||
}
|
}
|
||||||
|
|
||||||
// if there is no config file, we don't want to fail and instead just skip it
|
// if there is no config file, we don't want to fail and instead just skip it
|
||||||
if err != nil && !errors.Is(err, databrickscfg.ErrNoConfiguration) {
|
if err != nil && !errors.Is(err, profile.ErrNoConfiguration) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -233,11 +235,12 @@ func SetAccountClient(ctx context.Context, a *databricks.AccountClient) context.
|
||||||
}
|
}
|
||||||
|
|
||||||
func AskForWorkspaceProfile(ctx context.Context) (string, error) {
|
func AskForWorkspaceProfile(ctx context.Context) (string, error) {
|
||||||
path, err := databrickscfg.GetPath(ctx)
|
profiler := profile.GetProfiler(ctx)
|
||||||
|
path, err := profiler.GetPath(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("cannot determine Databricks config file path: %w", err)
|
return "", fmt.Errorf("cannot determine Databricks config file path: %w", err)
|
||||||
}
|
}
|
||||||
file, profiles, err := databrickscfg.LoadProfiles(ctx, databrickscfg.MatchWorkspaceProfiles)
|
profiles, err := profiler.LoadProfiles(ctx, profile.MatchWorkspaceProfiles)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
@ -248,7 +251,7 @@ func AskForWorkspaceProfile(ctx context.Context) (string, error) {
|
||||||
return profiles[0].Name, nil
|
return profiles[0].Name, nil
|
||||||
}
|
}
|
||||||
i, _, err := cmdio.RunSelect(ctx, &promptui.Select{
|
i, _, err := cmdio.RunSelect(ctx, &promptui.Select{
|
||||||
Label: fmt.Sprintf("Workspace profiles defined in %s", file),
|
Label: fmt.Sprintf("Workspace profiles defined in %s", path),
|
||||||
Items: profiles,
|
Items: profiles,
|
||||||
Searcher: profiles.SearchCaseInsensitive,
|
Searcher: profiles.SearchCaseInsensitive,
|
||||||
StartInSearchMode: true,
|
StartInSearchMode: true,
|
||||||
|
@ -266,11 +269,12 @@ func AskForWorkspaceProfile(ctx context.Context) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func AskForAccountProfile(ctx context.Context) (string, error) {
|
func AskForAccountProfile(ctx context.Context) (string, error) {
|
||||||
path, err := databrickscfg.GetPath(ctx)
|
profiler := profile.GetProfiler(ctx)
|
||||||
|
path, err := profiler.GetPath(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("cannot determine Databricks config file path: %w", err)
|
return "", fmt.Errorf("cannot determine Databricks config file path: %w", err)
|
||||||
}
|
}
|
||||||
file, profiles, err := databrickscfg.LoadProfiles(ctx, databrickscfg.MatchAccountProfiles)
|
profiles, err := profiler.LoadProfiles(ctx, profile.MatchAccountProfiles)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
@ -281,7 +285,7 @@ func AskForAccountProfile(ctx context.Context) (string, error) {
|
||||||
return profiles[0].Name, nil
|
return profiles[0].Name, nil
|
||||||
}
|
}
|
||||||
i, _, err := cmdio.RunSelect(ctx, &promptui.Select{
|
i, _, err := cmdio.RunSelect(ctx, &promptui.Select{
|
||||||
Label: fmt.Sprintf("Account profiles defined in %s", file),
|
Label: fmt.Sprintf("Account profiles defined in %s", path),
|
||||||
Items: profiles,
|
Items: profiles,
|
||||||
Searcher: profiles.SearchCaseInsensitive,
|
Searcher: profiles.SearchCaseInsensitive,
|
||||||
StartInSearchMode: true,
|
StartInSearchMode: true,
|
||||||
|
|
|
@ -1,106 +1,26 @@
|
||||||
package cache
|
package cache
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io/fs"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
type TokenCache interface {
|
||||||
// where the token cache is stored
|
Store(key string, t *oauth2.Token) error
|
||||||
tokenCacheFile = ".databricks/token-cache.json"
|
Lookup(key string) (*oauth2.Token, error)
|
||||||
|
|
||||||
// only the owner of the file has full execute, read, and write access
|
|
||||||
ownerExecReadWrite = 0o700
|
|
||||||
|
|
||||||
// only the owner of the file has full read and write access
|
|
||||||
ownerReadWrite = 0o600
|
|
||||||
|
|
||||||
// format versioning leaves some room for format improvement
|
|
||||||
tokenCacheVersion = 1
|
|
||||||
)
|
|
||||||
|
|
||||||
var ErrNotConfigured = errors.New("databricks OAuth is not configured for this host")
|
|
||||||
|
|
||||||
// this implementation requires the calling code to do a machine-wide lock,
|
|
||||||
// otherwise the file might get corrupt.
|
|
||||||
type TokenCache struct {
|
|
||||||
Version int `json:"version"`
|
|
||||||
Tokens map[string]*oauth2.Token `json:"tokens"`
|
|
||||||
|
|
||||||
fileLocation string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TokenCache) Store(key string, t *oauth2.Token) error {
|
var tokenCache int
|
||||||
err := c.load()
|
|
||||||
if errors.Is(err, fs.ErrNotExist) {
|
func WithTokenCache(ctx context.Context, c TokenCache) context.Context {
|
||||||
dir := filepath.Dir(c.fileLocation)
|
return context.WithValue(ctx, &tokenCache, c)
|
||||||
err = os.MkdirAll(dir, ownerExecReadWrite)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("mkdir: %w", err)
|
|
||||||
}
|
|
||||||
} else if err != nil {
|
|
||||||
return fmt.Errorf("load: %w", err)
|
|
||||||
}
|
|
||||||
c.Version = tokenCacheVersion
|
|
||||||
if c.Tokens == nil {
|
|
||||||
c.Tokens = map[string]*oauth2.Token{}
|
|
||||||
}
|
|
||||||
c.Tokens[key] = t
|
|
||||||
raw, err := json.MarshalIndent(c, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("marshal: %w", err)
|
|
||||||
}
|
|
||||||
return os.WriteFile(c.fileLocation, raw, ownerReadWrite)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TokenCache) Lookup(key string) (*oauth2.Token, error) {
|
func GetTokenCache(ctx context.Context) TokenCache {
|
||||||
err := c.load()
|
c, ok := ctx.Value(&tokenCache).(TokenCache)
|
||||||
if errors.Is(err, fs.ErrNotExist) {
|
|
||||||
return nil, ErrNotConfigured
|
|
||||||
} else if err != nil {
|
|
||||||
return nil, fmt.Errorf("load: %w", err)
|
|
||||||
}
|
|
||||||
t, ok := c.Tokens[key]
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, ErrNotConfigured
|
return &FileTokenCache{}
|
||||||
}
|
}
|
||||||
return t, nil
|
return c
|
||||||
}
|
|
||||||
|
|
||||||
func (c *TokenCache) location() (string, error) {
|
|
||||||
home, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("home: %w", err)
|
|
||||||
}
|
|
||||||
return filepath.Join(home, tokenCacheFile), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *TokenCache) load() error {
|
|
||||||
loc, err := c.location()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.fileLocation = loc
|
|
||||||
raw, err := os.ReadFile(loc)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("read: %w", err)
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(raw, c)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("parse: %w", err)
|
|
||||||
}
|
|
||||||
if c.Version != tokenCacheVersion {
|
|
||||||
// in the later iterations we could do state upgraders,
|
|
||||||
// so that we transform token cache from v1 to v2 without
|
|
||||||
// losing the tokens and asking the user to re-authenticate.
|
|
||||||
return fmt.Errorf("needs version %d, got version %d",
|
|
||||||
tokenCacheVersion, c.Version)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,108 @@
|
||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// where the token cache is stored
|
||||||
|
tokenCacheFile = ".databricks/token-cache.json"
|
||||||
|
|
||||||
|
// only the owner of the file has full execute, read, and write access
|
||||||
|
ownerExecReadWrite = 0o700
|
||||||
|
|
||||||
|
// only the owner of the file has full read and write access
|
||||||
|
ownerReadWrite = 0o600
|
||||||
|
|
||||||
|
// format versioning leaves some room for format improvement
|
||||||
|
tokenCacheVersion = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrNotConfigured = errors.New("databricks OAuth is not configured for this host")
|
||||||
|
|
||||||
|
// this implementation requires the calling code to do a machine-wide lock,
|
||||||
|
// otherwise the file might get corrupt.
|
||||||
|
type FileTokenCache struct {
|
||||||
|
Version int `json:"version"`
|
||||||
|
Tokens map[string]*oauth2.Token `json:"tokens"`
|
||||||
|
|
||||||
|
fileLocation string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *FileTokenCache) Store(key string, t *oauth2.Token) error {
|
||||||
|
err := c.load()
|
||||||
|
if errors.Is(err, fs.ErrNotExist) {
|
||||||
|
dir := filepath.Dir(c.fileLocation)
|
||||||
|
err = os.MkdirAll(dir, ownerExecReadWrite)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("mkdir: %w", err)
|
||||||
|
}
|
||||||
|
} else if err != nil {
|
||||||
|
return fmt.Errorf("load: %w", err)
|
||||||
|
}
|
||||||
|
c.Version = tokenCacheVersion
|
||||||
|
if c.Tokens == nil {
|
||||||
|
c.Tokens = map[string]*oauth2.Token{}
|
||||||
|
}
|
||||||
|
c.Tokens[key] = t
|
||||||
|
raw, err := json.MarshalIndent(c, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal: %w", err)
|
||||||
|
}
|
||||||
|
return os.WriteFile(c.fileLocation, raw, ownerReadWrite)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *FileTokenCache) Lookup(key string) (*oauth2.Token, error) {
|
||||||
|
err := c.load()
|
||||||
|
if errors.Is(err, fs.ErrNotExist) {
|
||||||
|
return nil, ErrNotConfigured
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, fmt.Errorf("load: %w", err)
|
||||||
|
}
|
||||||
|
t, ok := c.Tokens[key]
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrNotConfigured
|
||||||
|
}
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *FileTokenCache) location() (string, error) {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("home: %w", err)
|
||||||
|
}
|
||||||
|
return filepath.Join(home, tokenCacheFile), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *FileTokenCache) load() error {
|
||||||
|
loc, err := c.location()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.fileLocation = loc
|
||||||
|
raw, err := os.ReadFile(loc)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read: %w", err)
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(raw, c)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse: %w", err)
|
||||||
|
}
|
||||||
|
if c.Version != tokenCacheVersion {
|
||||||
|
// in the later iterations we could do state upgraders,
|
||||||
|
// so that we transform token cache from v1 to v2 without
|
||||||
|
// losing the tokens and asking the user to re-authenticate.
|
||||||
|
return fmt.Errorf("needs version %d, got version %d",
|
||||||
|
tokenCacheVersion, c.Version)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ TokenCache = (*FileTokenCache)(nil)
|
|
@ -27,7 +27,7 @@ func setup(t *testing.T) string {
|
||||||
|
|
||||||
func TestStoreAndLookup(t *testing.T) {
|
func TestStoreAndLookup(t *testing.T) {
|
||||||
setup(t)
|
setup(t)
|
||||||
c := &TokenCache{}
|
c := &FileTokenCache{}
|
||||||
err := c.Store("x", &oauth2.Token{
|
err := c.Store("x", &oauth2.Token{
|
||||||
AccessToken: "abc",
|
AccessToken: "abc",
|
||||||
})
|
})
|
||||||
|
@ -38,7 +38,7 @@ func TestStoreAndLookup(t *testing.T) {
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
l := &TokenCache{}
|
l := &FileTokenCache{}
|
||||||
tok, err := l.Lookup("x")
|
tok, err := l.Lookup("x")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "abc", tok.AccessToken)
|
assert.Equal(t, "abc", tok.AccessToken)
|
||||||
|
@ -50,7 +50,7 @@ func TestStoreAndLookup(t *testing.T) {
|
||||||
|
|
||||||
func TestNoCacheFileReturnsErrNotConfigured(t *testing.T) {
|
func TestNoCacheFileReturnsErrNotConfigured(t *testing.T) {
|
||||||
setup(t)
|
setup(t)
|
||||||
l := &TokenCache{}
|
l := &FileTokenCache{}
|
||||||
_, err := l.Lookup("x")
|
_, err := l.Lookup("x")
|
||||||
assert.Equal(t, ErrNotConfigured, err)
|
assert.Equal(t, ErrNotConfigured, err)
|
||||||
}
|
}
|
||||||
|
@ -63,7 +63,7 @@ func TestLoadCorruptFile(t *testing.T) {
|
||||||
err = os.WriteFile(f, []byte("abc"), ownerExecReadWrite)
|
err = os.WriteFile(f, []byte("abc"), ownerExecReadWrite)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
l := &TokenCache{}
|
l := &FileTokenCache{}
|
||||||
_, err = l.Lookup("x")
|
_, err = l.Lookup("x")
|
||||||
assert.EqualError(t, err, "load: parse: invalid character 'a' looking for beginning of value")
|
assert.EqualError(t, err, "load: parse: invalid character 'a' looking for beginning of value")
|
||||||
}
|
}
|
||||||
|
@ -76,14 +76,14 @@ func TestLoadWrongVersion(t *testing.T) {
|
||||||
err = os.WriteFile(f, []byte(`{"version": 823, "things": []}`), ownerExecReadWrite)
|
err = os.WriteFile(f, []byte(`{"version": 823, "things": []}`), ownerExecReadWrite)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
l := &TokenCache{}
|
l := &FileTokenCache{}
|
||||||
_, err = l.Lookup("x")
|
_, err = l.Lookup("x")
|
||||||
assert.EqualError(t, err, "load: needs version 1, got version 823")
|
assert.EqualError(t, err, "load: needs version 1, got version 823")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDevNull(t *testing.T) {
|
func TestDevNull(t *testing.T) {
|
||||||
t.Setenv(homeEnvVar, "/dev/null")
|
t.Setenv(homeEnvVar, "/dev/null")
|
||||||
l := &TokenCache{}
|
l := &FileTokenCache{}
|
||||||
_, err := l.Lookup("x")
|
_, err := l.Lookup("x")
|
||||||
// macOS/Linux: load: read: open /dev/null/.databricks/token-cache.json:
|
// macOS/Linux: load: read: open /dev/null/.databricks/token-cache.json:
|
||||||
// windows: databricks OAuth is not configured for this host
|
// windows: databricks OAuth is not configured for this host
|
||||||
|
@ -95,7 +95,7 @@ func TestStoreOnDev(t *testing.T) {
|
||||||
t.SkipNow()
|
t.SkipNow()
|
||||||
}
|
}
|
||||||
t.Setenv(homeEnvVar, "/dev")
|
t.Setenv(homeEnvVar, "/dev")
|
||||||
c := &TokenCache{}
|
c := &FileTokenCache{}
|
||||||
err := c.Store("x", &oauth2.Token{
|
err := c.Store("x", &oauth2.Token{
|
||||||
AccessToken: "abc",
|
AccessToken: "abc",
|
||||||
})
|
})
|
|
@ -0,0 +1,26 @@
|
||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type InMemoryTokenCache struct {
|
||||||
|
Tokens map[string]*oauth2.Token
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lookup implements TokenCache.
|
||||||
|
func (i *InMemoryTokenCache) Lookup(key string) (*oauth2.Token, error) {
|
||||||
|
token, ok := i.Tokens[key]
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrNotConfigured
|
||||||
|
}
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store implements TokenCache.
|
||||||
|
func (i *InMemoryTokenCache) Store(key string, t *oauth2.Token) error {
|
||||||
|
i.Tokens[key] = t
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ TokenCache = (*InMemoryTokenCache)(nil)
|
|
@ -0,0 +1,44 @@
|
||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInMemoryCacheHit(t *testing.T) {
|
||||||
|
token := &oauth2.Token{
|
||||||
|
AccessToken: "abc",
|
||||||
|
}
|
||||||
|
c := &InMemoryTokenCache{
|
||||||
|
Tokens: map[string]*oauth2.Token{
|
||||||
|
"key": token,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
res, err := c.Lookup("key")
|
||||||
|
assert.Equal(t, res, token)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryCacheMiss(t *testing.T) {
|
||||||
|
c := &InMemoryTokenCache{
|
||||||
|
Tokens: map[string]*oauth2.Token{},
|
||||||
|
}
|
||||||
|
_, err := c.Lookup("key")
|
||||||
|
assert.ErrorIs(t, err, ErrNotConfigured)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryCacheStore(t *testing.T) {
|
||||||
|
token := &oauth2.Token{
|
||||||
|
AccessToken: "abc",
|
||||||
|
}
|
||||||
|
c := &InMemoryTokenCache{
|
||||||
|
Tokens: map[string]*oauth2.Token{},
|
||||||
|
}
|
||||||
|
err := c.Store("key", token)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
res, err := c.Lookup("key")
|
||||||
|
assert.Equal(t, res, token)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
|
@ -20,6 +20,20 @@ import (
|
||||||
"golang.org/x/oauth2/authhandler"
|
"golang.org/x/oauth2/authhandler"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var apiClientForOauth int
|
||||||
|
|
||||||
|
func WithApiClientForOAuth(ctx context.Context, c *httpclient.ApiClient) context.Context {
|
||||||
|
return context.WithValue(ctx, &apiClientForOauth, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetApiClientForOAuth(ctx context.Context) *httpclient.ApiClient {
|
||||||
|
c, ok := ctx.Value(&apiClientForOauth).(*httpclient.ApiClient)
|
||||||
|
if !ok {
|
||||||
|
return httpclient.NewApiClient(httpclient.ClientConfig{})
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// these values are predefined by Databricks as a public client
|
// these values are predefined by Databricks as a public client
|
||||||
// and is specific to this application only. Using these values
|
// and is specific to this application only. Using these values
|
||||||
|
@ -28,7 +42,7 @@ const (
|
||||||
appRedirectAddr = "localhost:8020"
|
appRedirectAddr = "localhost:8020"
|
||||||
|
|
||||||
// maximum amount of time to acquire listener on appRedirectAddr
|
// maximum amount of time to acquire listener on appRedirectAddr
|
||||||
DefaultTimeout = 45 * time.Second
|
listenerTimeout = 45 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
var ( // Databricks SDK API: `databricks OAuth is not` will be checked for presence
|
var ( // Databricks SDK API: `databricks OAuth is not` will be checked for presence
|
||||||
|
@ -42,14 +56,13 @@ type PersistentAuth struct {
|
||||||
AccountID string
|
AccountID string
|
||||||
|
|
||||||
http *httpclient.ApiClient
|
http *httpclient.ApiClient
|
||||||
cache tokenCache
|
cache cache.TokenCache
|
||||||
ln net.Listener
|
ln net.Listener
|
||||||
browser func(string) error
|
browser func(string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type tokenCache interface {
|
func (a *PersistentAuth) SetApiClient(h *httpclient.ApiClient) {
|
||||||
Store(key string, t *oauth2.Token) error
|
a.http = h
|
||||||
Lookup(key string) (*oauth2.Token, error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *PersistentAuth) Load(ctx context.Context) (*oauth2.Token, error) {
|
func (a *PersistentAuth) Load(ctx context.Context) (*oauth2.Token, error) {
|
||||||
|
@ -136,12 +149,10 @@ func (a *PersistentAuth) init(ctx context.Context) error {
|
||||||
return ErrFetchCredentials
|
return ErrFetchCredentials
|
||||||
}
|
}
|
||||||
if a.http == nil {
|
if a.http == nil {
|
||||||
a.http = httpclient.NewApiClient(httpclient.ClientConfig{
|
a.http = GetApiClientForOAuth(ctx)
|
||||||
// noop
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
if a.cache == nil {
|
if a.cache == nil {
|
||||||
a.cache = &cache.TokenCache{}
|
a.cache = cache.GetTokenCache(ctx)
|
||||||
}
|
}
|
||||||
if a.browser == nil {
|
if a.browser == nil {
|
||||||
a.browser = browser.OpenURL
|
a.browser = browser.OpenURL
|
||||||
|
@ -149,7 +160,7 @@ func (a *PersistentAuth) init(ctx context.Context) error {
|
||||||
// try acquire listener, which we also use as a machine-local
|
// try acquire listener, which we also use as a machine-local
|
||||||
// exclusive lock to prevent token cache corruption in the scope
|
// exclusive lock to prevent token cache corruption in the scope
|
||||||
// of developer machine, where this command runs.
|
// of developer machine, where this command runs.
|
||||||
listener, err := retries.Poll(ctx, DefaultTimeout,
|
listener, err := retries.Poll(ctx, listenerTimeout,
|
||||||
func() (*net.Listener, *retries.Err) {
|
func() (*net.Listener, *retries.Err) {
|
||||||
var lc net.ListenConfig
|
var lc net.ListenConfig
|
||||||
l, err := lc.Listen(ctx, "tcp", appRedirectAddr)
|
l, err := lc.Listen(ctx, "tcp", appRedirectAddr)
|
||||||
|
|
|
@ -68,7 +68,7 @@ func TestLoaderErrorsOnInvalidFile(t *testing.T) {
|
||||||
Loaders: []config.Loader{
|
Loaders: []config.Loader{
|
||||||
ResolveProfileFromHost,
|
ResolveProfileFromHost,
|
||||||
},
|
},
|
||||||
ConfigFile: "testdata/badcfg",
|
ConfigFile: "profile/testdata/badcfg",
|
||||||
Host: "https://default",
|
Host: "https://default",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -81,7 +81,7 @@ func TestLoaderSkipsNoMatchingHost(t *testing.T) {
|
||||||
Loaders: []config.Loader{
|
Loaders: []config.Loader{
|
||||||
ResolveProfileFromHost,
|
ResolveProfileFromHost,
|
||||||
},
|
},
|
||||||
ConfigFile: "testdata/databrickscfg",
|
ConfigFile: "profile/testdata/databrickscfg",
|
||||||
Host: "https://noneofthehostsmatch",
|
Host: "https://noneofthehostsmatch",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -95,7 +95,7 @@ func TestLoaderMatchingHost(t *testing.T) {
|
||||||
Loaders: []config.Loader{
|
Loaders: []config.Loader{
|
||||||
ResolveProfileFromHost,
|
ResolveProfileFromHost,
|
||||||
},
|
},
|
||||||
ConfigFile: "testdata/databrickscfg",
|
ConfigFile: "profile/testdata/databrickscfg",
|
||||||
Host: "https://default",
|
Host: "https://default",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -110,7 +110,7 @@ func TestLoaderMatchingHostWithQuery(t *testing.T) {
|
||||||
Loaders: []config.Loader{
|
Loaders: []config.Loader{
|
||||||
ResolveProfileFromHost,
|
ResolveProfileFromHost,
|
||||||
},
|
},
|
||||||
ConfigFile: "testdata/databrickscfg",
|
ConfigFile: "profile/testdata/databrickscfg",
|
||||||
Host: "https://query/?foo=bar",
|
Host: "https://query/?foo=bar",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -125,7 +125,7 @@ func TestLoaderErrorsOnMultipleMatches(t *testing.T) {
|
||||||
Loaders: []config.Loader{
|
Loaders: []config.Loader{
|
||||||
ResolveProfileFromHost,
|
ResolveProfileFromHost,
|
||||||
},
|
},
|
||||||
ConfigFile: "testdata/databrickscfg",
|
ConfigFile: "profile/testdata/databrickscfg",
|
||||||
Host: "https://foo/bar",
|
Host: "https://foo/bar",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@ func TestLoadOrCreate_NotAllowed(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLoadOrCreate_Bad(t *testing.T) {
|
func TestLoadOrCreate_Bad(t *testing.T) {
|
||||||
path := "testdata/badcfg"
|
path := "profile/testdata/badcfg"
|
||||||
file, err := loadOrCreateConfigFile(path)
|
file, err := loadOrCreateConfigFile(path)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Nil(t, file)
|
assert.Nil(t, file)
|
||||||
|
@ -40,7 +40,7 @@ func TestMatchOrCreateSection_Direct(t *testing.T) {
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
Profile: "query",
|
Profile: "query",
|
||||||
}
|
}
|
||||||
file, err := loadOrCreateConfigFile("testdata/databrickscfg")
|
file, err := loadOrCreateConfigFile("profile/testdata/databrickscfg")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
@ -54,7 +54,7 @@ func TestMatchOrCreateSection_AccountID(t *testing.T) {
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
AccountID: "abc",
|
AccountID: "abc",
|
||||||
}
|
}
|
||||||
file, err := loadOrCreateConfigFile("testdata/databrickscfg")
|
file, err := loadOrCreateConfigFile("profile/testdata/databrickscfg")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
@ -68,7 +68,7 @@ func TestMatchOrCreateSection_NormalizeHost(t *testing.T) {
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
Host: "https://query/?o=abracadabra",
|
Host: "https://query/?o=abracadabra",
|
||||||
}
|
}
|
||||||
file, err := loadOrCreateConfigFile("testdata/databrickscfg")
|
file, err := loadOrCreateConfigFile("profile/testdata/databrickscfg")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
@ -80,7 +80,7 @@ func TestMatchOrCreateSection_NormalizeHost(t *testing.T) {
|
||||||
|
|
||||||
func TestMatchOrCreateSection_NoProfileOrHost(t *testing.T) {
|
func TestMatchOrCreateSection_NoProfileOrHost(t *testing.T) {
|
||||||
cfg := &config.Config{}
|
cfg := &config.Config{}
|
||||||
file, err := loadOrCreateConfigFile("testdata/databrickscfg")
|
file, err := loadOrCreateConfigFile("profile/testdata/databrickscfg")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
@ -92,7 +92,7 @@ func TestMatchOrCreateSection_MultipleProfiles(t *testing.T) {
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
Host: "https://foo",
|
Host: "https://foo",
|
||||||
}
|
}
|
||||||
file, err := loadOrCreateConfigFile("testdata/databrickscfg")
|
file, err := loadOrCreateConfigFile("profile/testdata/databrickscfg")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
@ -105,7 +105,7 @@ func TestMatchOrCreateSection_NewProfile(t *testing.T) {
|
||||||
Host: "https://bar",
|
Host: "https://bar",
|
||||||
Profile: "delirium",
|
Profile: "delirium",
|
||||||
}
|
}
|
||||||
file, err := loadOrCreateConfigFile("testdata/databrickscfg")
|
file, err := loadOrCreateConfigFile("profile/testdata/databrickscfg")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
package profile
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
var profiler int
|
||||||
|
|
||||||
|
func WithProfiler(ctx context.Context, p Profiler) context.Context {
|
||||||
|
return context.WithValue(ctx, &profiler, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetProfiler(ctx context.Context) Profiler {
|
||||||
|
p, ok := ctx.Value(&profiler).(Profiler)
|
||||||
|
if !ok {
|
||||||
|
return DefaultProfiler
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
|
@ -0,0 +1,100 @@
|
||||||
|
package profile
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/databricks/cli/libs/env"
|
||||||
|
"github.com/databricks/databricks-sdk-go/config"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
type FileProfilerImpl struct{}
|
||||||
|
|
||||||
|
func (f FileProfilerImpl) getPath(ctx context.Context, replaceHomeDirWithTilde bool) (string, error) {
|
||||||
|
configFile := env.Get(ctx, "DATABRICKS_CONFIG_FILE")
|
||||||
|
if configFile == "" {
|
||||||
|
configFile = "~/.databrickscfg"
|
||||||
|
}
|
||||||
|
if !replaceHomeDirWithTilde {
|
||||||
|
return configFile, nil
|
||||||
|
}
|
||||||
|
homedir, err := env.UserHomeDir(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
configFile = strings.Replace(configFile, homedir, "~", 1)
|
||||||
|
return configFile, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the path to the .databrickscfg file, falling back to the default in the current user's home directory.
|
||||||
|
func (f FileProfilerImpl) GetPath(ctx context.Context) (string, error) {
|
||||||
|
fp, err := f.getPath(ctx, true)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return filepath.Clean(fp), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var ErrNoConfiguration = errors.New("no configuration file found")
|
||||||
|
|
||||||
|
func (f FileProfilerImpl) Get(ctx context.Context) (*config.File, error) {
|
||||||
|
path, err := f.getPath(ctx, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cannot determine Databricks config file path: %w", err)
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(path, "~") {
|
||||||
|
homedir, err := env.UserHomeDir(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
path = filepath.Join(homedir, path[1:])
|
||||||
|
}
|
||||||
|
configFile, err := config.LoadFile(path)
|
||||||
|
if errors.Is(err, fs.ErrNotExist) {
|
||||||
|
// downstreams depend on ErrNoConfiguration. TODO: expose this error through SDK
|
||||||
|
return nil, fmt.Errorf("%w at %s; please create one by running 'databricks configure'", ErrNoConfiguration, path)
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return configFile, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f FileProfilerImpl) LoadProfiles(ctx context.Context, fn ProfileMatchFunction) (profiles Profiles, err error) {
|
||||||
|
file, err := f.Get(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cannot load Databricks config file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iterate over sections and collect matching profiles.
|
||||||
|
for _, v := range file.Sections() {
|
||||||
|
all := v.KeysHash()
|
||||||
|
host, ok := all["host"]
|
||||||
|
if !ok {
|
||||||
|
// invalid profile
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
profile := Profile{
|
||||||
|
Name: v.Name(),
|
||||||
|
Host: host,
|
||||||
|
AccountID: all["account_id"],
|
||||||
|
}
|
||||||
|
if fn(profile) {
|
||||||
|
profiles = append(profiles, profile)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func ProfileCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||||
|
profiles, err := DefaultProfiler.LoadProfiles(cmd.Context(), MatchAllProfiles)
|
||||||
|
if err != nil {
|
||||||
|
return nil, cobra.ShellCompDirectiveError
|
||||||
|
}
|
||||||
|
return profiles.Names(), cobra.ShellCompDirectiveNoFileComp
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package databrickscfg
|
package profile
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -32,7 +32,8 @@ func TestLoadProfilesReturnsHomedirAsTilde(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
ctx = env.WithUserHomeDir(ctx, "testdata")
|
ctx = env.WithUserHomeDir(ctx, "testdata")
|
||||||
ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "./testdata/databrickscfg")
|
ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "./testdata/databrickscfg")
|
||||||
file, _, err := LoadProfiles(ctx, func(p Profile) bool { return true })
|
profiler := FileProfilerImpl{}
|
||||||
|
file, err := profiler.GetPath(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, filepath.Clean("~/databrickscfg"), file)
|
require.Equal(t, filepath.Clean("~/databrickscfg"), file)
|
||||||
}
|
}
|
||||||
|
@ -41,7 +42,8 @@ func TestLoadProfilesReturnsHomedirAsTildeExoticFile(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
ctx = env.WithUserHomeDir(ctx, "testdata")
|
ctx = env.WithUserHomeDir(ctx, "testdata")
|
||||||
ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "~/databrickscfg")
|
ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "~/databrickscfg")
|
||||||
file, _, err := LoadProfiles(ctx, func(p Profile) bool { return true })
|
profiler := FileProfilerImpl{}
|
||||||
|
file, err := profiler.GetPath(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, filepath.Clean("~/databrickscfg"), file)
|
require.Equal(t, filepath.Clean("~/databrickscfg"), file)
|
||||||
}
|
}
|
||||||
|
@ -49,7 +51,8 @@ func TestLoadProfilesReturnsHomedirAsTildeExoticFile(t *testing.T) {
|
||||||
func TestLoadProfilesReturnsHomedirAsTildeDefaultFile(t *testing.T) {
|
func TestLoadProfilesReturnsHomedirAsTildeDefaultFile(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
ctx = env.WithUserHomeDir(ctx, "testdata/sample-home")
|
ctx = env.WithUserHomeDir(ctx, "testdata/sample-home")
|
||||||
file, _, err := LoadProfiles(ctx, func(p Profile) bool { return true })
|
profiler := FileProfilerImpl{}
|
||||||
|
file, err := profiler.GetPath(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, filepath.Clean("~/.databrickscfg"), file)
|
require.Equal(t, filepath.Clean("~/.databrickscfg"), file)
|
||||||
}
|
}
|
||||||
|
@ -57,14 +60,16 @@ func TestLoadProfilesReturnsHomedirAsTildeDefaultFile(t *testing.T) {
|
||||||
func TestLoadProfilesNoConfiguration(t *testing.T) {
|
func TestLoadProfilesNoConfiguration(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
ctx = env.WithUserHomeDir(ctx, "testdata")
|
ctx = env.WithUserHomeDir(ctx, "testdata")
|
||||||
_, _, err := LoadProfiles(ctx, func(p Profile) bool { return true })
|
profiler := FileProfilerImpl{}
|
||||||
|
_, err := profiler.LoadProfiles(ctx, MatchAllProfiles)
|
||||||
require.ErrorIs(t, err, ErrNoConfiguration)
|
require.ErrorIs(t, err, ErrNoConfiguration)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLoadProfilesMatchWorkspace(t *testing.T) {
|
func TestLoadProfilesMatchWorkspace(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "./testdata/databrickscfg")
|
ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "./testdata/databrickscfg")
|
||||||
_, profiles, err := LoadProfiles(ctx, MatchWorkspaceProfiles)
|
profiler := FileProfilerImpl{}
|
||||||
|
profiles, err := profiler.LoadProfiles(ctx, MatchWorkspaceProfiles)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []string{"DEFAULT", "query", "foo1", "foo2"}, profiles.Names())
|
assert.Equal(t, []string{"DEFAULT", "query", "foo1", "foo2"}, profiles.Names())
|
||||||
}
|
}
|
||||||
|
@ -72,7 +77,8 @@ func TestLoadProfilesMatchWorkspace(t *testing.T) {
|
||||||
func TestLoadProfilesMatchAccount(t *testing.T) {
|
func TestLoadProfilesMatchAccount(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "./testdata/databrickscfg")
|
ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "./testdata/databrickscfg")
|
||||||
_, profiles, err := LoadProfiles(ctx, MatchAccountProfiles)
|
profiler := FileProfilerImpl{}
|
||||||
|
profiles, err := profiler.LoadProfiles(ctx, MatchAccountProfiles)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []string{"acc"}, profiles.Names())
|
assert.Equal(t, []string{"acc"}, profiles.Names())
|
||||||
}
|
}
|
|
@ -0,0 +1,25 @@
|
||||||
|
package profile
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
type InMemoryProfiler struct {
|
||||||
|
Profiles Profiles
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPath implements Profiler.
|
||||||
|
func (i InMemoryProfiler) GetPath(context.Context) (string, error) {
|
||||||
|
return "<in memory>", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadProfiles implements Profiler.
|
||||||
|
func (i InMemoryProfiler) LoadProfiles(ctx context.Context, f ProfileMatchFunction) (Profiles, error) {
|
||||||
|
res := make(Profiles, 0)
|
||||||
|
for _, p := range i.Profiles {
|
||||||
|
if f(p) {
|
||||||
|
res = append(res, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Profiler = InMemoryProfiler{}
|
|
@ -0,0 +1,49 @@
|
||||||
|
package profile
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/databricks/databricks-sdk-go/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Profile holds a subset of the keys in a databrickscfg profile.
|
||||||
|
// It should only be used for prompting and filtering.
|
||||||
|
// Use its name to construct a config.Config.
|
||||||
|
type Profile struct {
|
||||||
|
Name string
|
||||||
|
Host string
|
||||||
|
AccountID string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Profile) Cloud() string {
|
||||||
|
cfg := config.Config{Host: p.Host}
|
||||||
|
switch {
|
||||||
|
case cfg.IsAws():
|
||||||
|
return "AWS"
|
||||||
|
case cfg.IsAzure():
|
||||||
|
return "Azure"
|
||||||
|
case cfg.IsGcp():
|
||||||
|
return "GCP"
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Profiles []Profile
|
||||||
|
|
||||||
|
// SearchCaseInsensitive implements the promptui.Searcher interface.
|
||||||
|
// This allows the user to immediately starting typing to narrow down the list.
|
||||||
|
func (p Profiles) SearchCaseInsensitive(input string, index int) bool {
|
||||||
|
input = strings.ToLower(input)
|
||||||
|
name := strings.ToLower(p[index].Name)
|
||||||
|
host := strings.ToLower(p[index].Host)
|
||||||
|
return strings.Contains(name, input) || strings.Contains(host, input)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Profiles) Names() []string {
|
||||||
|
names := make([]string, len(p))
|
||||||
|
for i, v := range p {
|
||||||
|
names[i] = v.Name
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
}
|
|
@ -0,0 +1,32 @@
|
||||||
|
package profile
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ProfileMatchFunction func(Profile) bool
|
||||||
|
|
||||||
|
func MatchWorkspaceProfiles(p Profile) bool {
|
||||||
|
return p.AccountID == ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func MatchAccountProfiles(p Profile) bool {
|
||||||
|
return p.Host != "" && p.AccountID != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func MatchAllProfiles(p Profile) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithName(name string) ProfileMatchFunction {
|
||||||
|
return func(p Profile) bool {
|
||||||
|
return p.Name == name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Profiler interface {
|
||||||
|
LoadProfiles(context.Context, ProfileMatchFunction) (Profiles, error)
|
||||||
|
GetPath(context.Context) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
var DefaultProfiler = FileProfilerImpl{}
|
|
@ -1,150 +0,0 @@
|
||||||
package databrickscfg
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io/fs"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/databricks/cli/libs/env"
|
|
||||||
"github.com/databricks/databricks-sdk-go/config"
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Profile holds a subset of the keys in a databrickscfg profile.
|
|
||||||
// It should only be used for prompting and filtering.
|
|
||||||
// Use its name to construct a config.Config.
|
|
||||||
type Profile struct {
|
|
||||||
Name string
|
|
||||||
Host string
|
|
||||||
AccountID string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p Profile) Cloud() string {
|
|
||||||
cfg := config.Config{Host: p.Host}
|
|
||||||
switch {
|
|
||||||
case cfg.IsAws():
|
|
||||||
return "AWS"
|
|
||||||
case cfg.IsAzure():
|
|
||||||
return "Azure"
|
|
||||||
case cfg.IsGcp():
|
|
||||||
return "GCP"
|
|
||||||
default:
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type Profiles []Profile
|
|
||||||
|
|
||||||
func (p Profiles) Names() []string {
|
|
||||||
names := make([]string, len(p))
|
|
||||||
for i, v := range p {
|
|
||||||
names[i] = v.Name
|
|
||||||
}
|
|
||||||
return names
|
|
||||||
}
|
|
||||||
|
|
||||||
// SearchCaseInsensitive implements the promptui.Searcher interface.
|
|
||||||
// This allows the user to immediately starting typing to narrow down the list.
|
|
||||||
func (p Profiles) SearchCaseInsensitive(input string, index int) bool {
|
|
||||||
input = strings.ToLower(input)
|
|
||||||
name := strings.ToLower(p[index].Name)
|
|
||||||
host := strings.ToLower(p[index].Host)
|
|
||||||
return strings.Contains(name, input) || strings.Contains(host, input)
|
|
||||||
}
|
|
||||||
|
|
||||||
type ProfileMatchFunction func(Profile) bool
|
|
||||||
|
|
||||||
func MatchWorkspaceProfiles(p Profile) bool {
|
|
||||||
return p.AccountID == ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func MatchAccountProfiles(p Profile) bool {
|
|
||||||
return p.Host != "" && p.AccountID != ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func MatchAllProfiles(p Profile) bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the path to the .databrickscfg file, falling back to the default in the current user's home directory.
|
|
||||||
func GetPath(ctx context.Context) (string, error) {
|
|
||||||
configFile := env.Get(ctx, "DATABRICKS_CONFIG_FILE")
|
|
||||||
if configFile == "" {
|
|
||||||
configFile = "~/.databrickscfg"
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(configFile, "~") {
|
|
||||||
homedir, err := env.UserHomeDir(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
configFile = filepath.Join(homedir, configFile[1:])
|
|
||||||
}
|
|
||||||
return configFile, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var ErrNoConfiguration = errors.New("no configuration file found")
|
|
||||||
|
|
||||||
func Get(ctx context.Context) (*config.File, error) {
|
|
||||||
path, err := GetPath(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("cannot determine Databricks config file path: %w", err)
|
|
||||||
}
|
|
||||||
configFile, err := config.LoadFile(path)
|
|
||||||
if errors.Is(err, fs.ErrNotExist) {
|
|
||||||
// downstreams depend on ErrNoConfiguration. TODO: expose this error through SDK
|
|
||||||
return nil, fmt.Errorf("%w at %s; please create one by running 'databricks configure'", ErrNoConfiguration, path)
|
|
||||||
} else if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return configFile, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadProfiles(ctx context.Context, fn ProfileMatchFunction) (file string, profiles Profiles, err error) {
|
|
||||||
f, err := Get(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, fmt.Errorf("cannot load Databricks config file: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Replace homedir with ~ if applicable.
|
|
||||||
// This is to make the output more readable.
|
|
||||||
file = filepath.Clean(f.Path())
|
|
||||||
home, err := env.UserHomeDir(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
homedir := filepath.Clean(home)
|
|
||||||
if strings.HasPrefix(file, homedir) {
|
|
||||||
file = "~" + file[len(homedir):]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterate over sections and collect matching profiles.
|
|
||||||
for _, v := range f.Sections() {
|
|
||||||
all := v.KeysHash()
|
|
||||||
host, ok := all["host"]
|
|
||||||
if !ok {
|
|
||||||
// invalid profile
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
profile := Profile{
|
|
||||||
Name: v.Name(),
|
|
||||||
Host: host,
|
|
||||||
AccountID: all["account_id"],
|
|
||||||
}
|
|
||||||
if fn(profile) {
|
|
||||||
profiles = append(profiles, profile)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func ProfileCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
|
||||||
_, profiles, err := LoadProfiles(cmd.Context(), MatchAllProfiles)
|
|
||||||
if err != nil {
|
|
||||||
return nil, cobra.ShellCompDirectiveError
|
|
||||||
}
|
|
||||||
return profiles.Names(), cobra.ShellCompDirectiveNoFileComp
|
|
||||||
}
|
|
Loading…
Reference in New Issue