mirror of https://github.com/databricks/cli.git
Compare commits
4 Commits
3e617e1b32
...
37f5e80ba5
Author | SHA1 | Date |
---|---|---|
Miles Yucht | 37f5e80ba5 | |
shreyas-goenka | cc112961ce | |
Miles Yucht | 517e1b3310 | |
Miles Yucht | f9675ab8ea |
|
@ -84,10 +84,13 @@ depends on the existing profiles you have set in your configuration file
|
|||
|
||||
var loginTimeout time.Duration
|
||||
var configureCluster bool
|
||||
var deviceCode bool
|
||||
cmd.Flags().DurationVar(&loginTimeout, "timeout", defaultTimeout,
|
||||
"Timeout for completing login challenge in the browser")
|
||||
cmd.Flags().BoolVar(&configureCluster, "configure-cluster", false,
|
||||
"Prompts to configure cluster")
|
||||
cmd.Flags().BoolVar(&deviceCode, "device-code", false,
|
||||
"Use device code flow for authentication")
|
||||
|
||||
cmd.RunE = func(cmd *cobra.Command, args []string) error {
|
||||
ctx := cmd.Context()
|
||||
|
@ -120,7 +123,11 @@ depends on the existing profiles you have set in your configuration file
|
|||
ctx, cancel := context.WithTimeout(ctx, loginTimeout)
|
||||
defer cancel()
|
||||
|
||||
if deviceCode {
|
||||
err = persistentAuth.DeviceCode(ctx)
|
||||
} else {
|
||||
err = persistentAuth.Challenge(ctx)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -112,8 +112,8 @@ func TestAccFsMkdirWhenFileExistsAtPath(t *testing.T) {
|
|||
// assert mkdir fails
|
||||
_, _, err = RequireErrorRun(t, "fs", "mkdir", path.Join(tmpDir, "hello"))
|
||||
|
||||
// Different cloud providers return different errors.
|
||||
regex := regexp.MustCompile(`(^|: )Path is a file: .*$|(^|: )Cannot create directory .* because .* is an existing file\.$|(^|: )mkdirs\(hadoopPath: .*, permission: rwxrwxrwx\): failed$`)
|
||||
// Different cloud providers or cloud configurations return different errors.
|
||||
regex := regexp.MustCompile(`(^|: )Path is a file: .*$|(^|: )Cannot create directory .* because .* is an existing file\.$|(^|: )mkdirs\(hadoopPath: .*, permission: rwxrwxrwx\): failed$|(^|: )"The specified path already exists.".*$`)
|
||||
assert.Regexp(t, regex, err.Error())
|
||||
})
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/databricks/cli/cmd/root"
|
||||
"github.com/databricks/cli/internal/acc"
|
||||
"github.com/databricks/cli/libs/flags"
|
||||
|
||||
"github.com/databricks/cli/cmd"
|
||||
|
@ -591,13 +592,10 @@ func setupWsfsExtensionsFiler(t *testing.T) (filer.Filer, string) {
|
|||
}
|
||||
|
||||
func setupDbfsFiler(t *testing.T) (filer.Filer, string) {
|
||||
t.Log(GetEnvOrSkipTest(t, "CLOUD_ENV"))
|
||||
_, wt := acc.WorkspaceTest(t)
|
||||
|
||||
w, err := databricks.NewWorkspaceClient()
|
||||
require.NoError(t, err)
|
||||
|
||||
tmpDir := TemporaryDbfsDir(t, w)
|
||||
f, err := filer.NewDbfsClient(w, tmpDir)
|
||||
tmpDir := TemporaryDbfsDir(t, wt.W)
|
||||
f, err := filer.NewDbfsClient(wt.W, tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
return f, path.Join("dbfs:/", tmpDir)
|
||||
|
|
|
@ -49,15 +49,23 @@ var ( // Databricks SDK API: `databricks OAuth is not` will be checked for prese
|
|||
ErrOAuthNotSupported = errors.New("databricks OAuth is not supported for this host")
|
||||
ErrNotConfigured = errors.New("databricks OAuth is not configured for this host")
|
||||
ErrFetchCredentials = errors.New("cannot fetch credentials")
|
||||
ErrDeviceCodeNotSupported = errors.New("device code flow is not supported for this host")
|
||||
)
|
||||
|
||||
type PersistentAuth struct {
|
||||
Host string
|
||||
AccountID string
|
||||
|
||||
// The client used when making requests to Databricks OAuth endpoints.
|
||||
http *httpclient.ApiClient
|
||||
|
||||
// A token cache for OAuth access & refresh tokens.
|
||||
cache cache.TokenCache
|
||||
|
||||
// A listener used to receive the OAuth callback. Not used for device-code flow.
|
||||
ln net.Listener
|
||||
|
||||
// A function to open a URL in the user's browser. Not used for device-code flow.
|
||||
browser func(string) error
|
||||
}
|
||||
|
||||
|
@ -113,11 +121,44 @@ func (a *PersistentAuth) ProfileName() string {
|
|||
return split[0]
|
||||
}
|
||||
|
||||
func (a *PersistentAuth) DeviceCode(ctx context.Context) error {
|
||||
err := a.init(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init: %w", err)
|
||||
}
|
||||
cfg, err := a.oauth2Config(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if cfg.Endpoint.DeviceAuthURL == "" {
|
||||
return ErrDeviceCodeNotSupported
|
||||
}
|
||||
ctx = a.http.InContextForOAuth2(ctx)
|
||||
deviceAuthResp, err := cfg.DeviceAuth(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error initiating device code flow: %w", err)
|
||||
}
|
||||
fmt.Printf("To authenticate, please visit %s and enter the code %s\n", deviceAuthResp.VerificationURI, deviceAuthResp.UserCode)
|
||||
token, err := cfg.DeviceAccessToken(ctx, deviceAuthResp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error retrieving token: %w", err)
|
||||
}
|
||||
err = a.cache.Store(a.key(), token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("store: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *PersistentAuth) Challenge(ctx context.Context) error {
|
||||
err := a.init(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init: %w", err)
|
||||
}
|
||||
err = a.initU2M(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init: %w", err)
|
||||
}
|
||||
cfg, err := a.oauth2Config(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -143,6 +184,8 @@ func (a *PersistentAuth) Challenge(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// init validates that the host and account id are set and initializes the http client and token cache.
|
||||
// It should be called before any other method on PersistentAuth.
|
||||
func (a *PersistentAuth) init(ctx context.Context) error {
|
||||
if a.Host == "" && a.AccountID == "" {
|
||||
return ErrFetchCredentials
|
||||
|
@ -153,6 +196,11 @@ func (a *PersistentAuth) init(ctx context.Context) error {
|
|||
if a.cache == nil {
|
||||
a.cache = cache.GetTokenCache(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// initU2M initializes the listener for the user-to-machine flow. It does not need to be called for device-code flow.
|
||||
func (a *PersistentAuth) initU2M(ctx context.Context) error {
|
||||
if a.browser == nil {
|
||||
a.browser = browser.OpenURL
|
||||
}
|
||||
|
@ -188,6 +236,7 @@ func (a *PersistentAuth) oidcEndpoints(ctx context.Context) (*oauthAuthorization
|
|||
return &oauthAuthorizationServer{
|
||||
AuthorizationEndpoint: fmt.Sprintf("%s/v1/authorize", prefix),
|
||||
TokenEndpoint: fmt.Sprintf("%s/v1/token", prefix),
|
||||
DeviceAuthorizationEndpoint: fmt.Sprintf("%s/v1/device_authorization", prefix),
|
||||
}, nil
|
||||
}
|
||||
var oauthEndpoints oauthAuthorizationServer
|
||||
|
@ -221,6 +270,7 @@ func (a *PersistentAuth) oauth2Config(ctx context.Context) (*oauth2.Config, erro
|
|||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: endpoints.AuthorizationEndpoint,
|
||||
TokenURL: endpoints.TokenEndpoint,
|
||||
DeviceAuthURL: endpoints.DeviceAuthorizationEndpoint,
|
||||
AuthStyle: oauth2.AuthStyleInParams,
|
||||
},
|
||||
RedirectURL: fmt.Sprintf("http://%s", appRedirectAddr),
|
||||
|
@ -262,4 +312,5 @@ func (a *PersistentAuth) randomString(size int) string {
|
|||
type oauthAuthorizationServer struct {
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"` // ../v1/authorize
|
||||
TokenEndpoint string `json:"token_endpoint"` // ../v1/token
|
||||
DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint"` // ../v1/device_authorization
|
||||
}
|
||||
|
|
|
@ -228,3 +228,38 @@ func TestChallengeFailed(t *testing.T) {
|
|||
assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeviceCode_Account(t *testing.T) {
|
||||
qa.HTTPFixtures{
|
||||
{
|
||||
Method: "POST",
|
||||
Resource: "/oidc/accounts/xyz/v1/device_authorization",
|
||||
Response: `{"device_code":"abc","user_code":"def","verification_uri":"ghi"}`,
|
||||
},
|
||||
{
|
||||
Method: "POST",
|
||||
Resource: "/oidc/accounts/xyz/v1/token",
|
||||
Response: `access_token=jkl&refresh_token=mnop`,
|
||||
},
|
||||
}.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) {
|
||||
ctx = useInsecureOAuthHttpClientForTests(ctx)
|
||||
tokenStored := false
|
||||
p := &PersistentAuth{
|
||||
Host: c.Config.Host,
|
||||
AccountID: "xyz",
|
||||
cache: &tokenCacheMock{
|
||||
store: func(key string, tok *oauth2.Token) error {
|
||||
assert.Equal(t, fmt.Sprintf("%s/oidc/accounts/xyz", c.Config.Host), key)
|
||||
assert.Equal(t, "mnop", tok.RefreshToken)
|
||||
tokenStored = true
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
defer p.Close()
|
||||
|
||||
err := p.DeviceCode(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, tokenStored)
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue