mirror of https://github.com/databricks/cli.git
hackathon device code flow
This commit is contained in:
parent
fb077a85d2
commit
f9675ab8ea
|
@ -84,10 +84,13 @@ depends on the existing profiles you have set in your configuration file
|
|||
|
||||
var loginTimeout time.Duration
|
||||
var configureCluster bool
|
||||
var deviceCode bool
|
||||
cmd.Flags().DurationVar(&loginTimeout, "timeout", defaultTimeout,
|
||||
"Timeout for completing login challenge in the browser")
|
||||
cmd.Flags().BoolVar(&configureCluster, "configure-cluster", false,
|
||||
"Prompts to configure cluster")
|
||||
cmd.Flags().BoolVar(&deviceCode, "device-code", false,
|
||||
"Use device code flow for authentication")
|
||||
|
||||
cmd.RunE = func(cmd *cobra.Command, args []string) error {
|
||||
ctx := cmd.Context()
|
||||
|
@ -120,7 +123,11 @@ depends on the existing profiles you have set in your configuration file
|
|||
ctx, cancel := context.WithTimeout(ctx, loginTimeout)
|
||||
defer cancel()
|
||||
|
||||
err = persistentAuth.Challenge(ctx)
|
||||
if deviceCode {
|
||||
err = persistentAuth.DeviceCode(ctx)
|
||||
} else {
|
||||
err = persistentAuth.Challenge(ctx)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -55,9 +55,16 @@ type PersistentAuth struct {
|
|||
Host string
|
||||
AccountID string
|
||||
|
||||
http *httpclient.ApiClient
|
||||
cache cache.TokenCache
|
||||
ln net.Listener
|
||||
// The client used when making requests to Databricks OAuth endpoints.
|
||||
http *httpclient.ApiClient
|
||||
|
||||
// A token cache for OAuth access & refresh tokens.
|
||||
cache cache.TokenCache
|
||||
|
||||
// A listener used to receive the OAuth callback. Not used for device-code flow.
|
||||
ln net.Listener
|
||||
|
||||
// A function to open a URL in the user's browser. Not used for device-code flow.
|
||||
browser func(string) error
|
||||
}
|
||||
|
||||
|
@ -113,11 +120,41 @@ func (a *PersistentAuth) ProfileName() string {
|
|||
return split[0]
|
||||
}
|
||||
|
||||
func (a *PersistentAuth) DeviceCode(ctx context.Context) error {
|
||||
err := a.init(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init: %w", err)
|
||||
}
|
||||
cfg, err := a.oauth2Config(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx = a.http.InContextForOAuth2(ctx)
|
||||
deviceAuthResp, err := cfg.DeviceAuth(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error initiating device code flow: %w", err)
|
||||
}
|
||||
fmt.Printf("To authenticate, please visit %s and enter the code %s\n", deviceAuthResp.VerificationURI, deviceAuthResp.UserCode)
|
||||
token, err := cfg.DeviceAccessToken(ctx, deviceAuthResp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error retrieving token: %w", err)
|
||||
}
|
||||
err = a.cache.Store(a.key(), token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("store: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *PersistentAuth) Challenge(ctx context.Context) error {
|
||||
err := a.init(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init: %w", err)
|
||||
}
|
||||
err = a.initU2M(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init: %w", err)
|
||||
}
|
||||
cfg, err := a.oauth2Config(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -143,6 +180,8 @@ func (a *PersistentAuth) Challenge(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// init validates that the host and account id are set and initializes the http client and token cache.
|
||||
// It should be called before any other method on PersistentAuth.
|
||||
func (a *PersistentAuth) init(ctx context.Context) error {
|
||||
if a.Host == "" && a.AccountID == "" {
|
||||
return ErrFetchCredentials
|
||||
|
@ -153,6 +192,11 @@ func (a *PersistentAuth) init(ctx context.Context) error {
|
|||
if a.cache == nil {
|
||||
a.cache = cache.GetTokenCache(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// initU2M initializes the listener for the user-to-machine flow. It does not need to be called for device-code flow.
|
||||
func (a *PersistentAuth) initU2M(ctx context.Context) error {
|
||||
if a.browser == nil {
|
||||
a.browser = browser.OpenURL
|
||||
}
|
||||
|
@ -186,8 +230,9 @@ func (a *PersistentAuth) oidcEndpoints(ctx context.Context) (*oauthAuthorization
|
|||
prefix := a.key()
|
||||
if a.AccountID != "" {
|
||||
return &oauthAuthorizationServer{
|
||||
AuthorizationEndpoint: fmt.Sprintf("%s/v1/authorize", prefix),
|
||||
TokenEndpoint: fmt.Sprintf("%s/v1/token", prefix),
|
||||
AuthorizationEndpoint: fmt.Sprintf("%s/v1/authorize", prefix),
|
||||
TokenEndpoint: fmt.Sprintf("%s/v1/token", prefix),
|
||||
DeviceAuthorizationEndpoint: fmt.Sprintf("%s/v1/device_authorization", prefix),
|
||||
}, nil
|
||||
}
|
||||
var oauthEndpoints oauthAuthorizationServer
|
||||
|
@ -219,9 +264,10 @@ func (a *PersistentAuth) oauth2Config(ctx context.Context) (*oauth2.Config, erro
|
|||
return &oauth2.Config{
|
||||
ClientID: appClientID,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: endpoints.AuthorizationEndpoint,
|
||||
TokenURL: endpoints.TokenEndpoint,
|
||||
AuthStyle: oauth2.AuthStyleInParams,
|
||||
AuthURL: endpoints.AuthorizationEndpoint,
|
||||
TokenURL: endpoints.TokenEndpoint,
|
||||
DeviceAuthURL: endpoints.DeviceAuthorizationEndpoint,
|
||||
AuthStyle: oauth2.AuthStyleInParams,
|
||||
},
|
||||
RedirectURL: fmt.Sprintf("http://%s", appRedirectAddr),
|
||||
Scopes: scopes,
|
||||
|
@ -260,6 +306,7 @@ func (a *PersistentAuth) randomString(size int) string {
|
|||
}
|
||||
|
||||
type oauthAuthorizationServer struct {
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"` // ../v1/authorize
|
||||
TokenEndpoint string `json:"token_endpoint"` // ../v1/token
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"` // ../v1/authorize
|
||||
TokenEndpoint string `json:"token_endpoint"` // ../v1/token
|
||||
DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint"` // ../v1/device_authorization
|
||||
}
|
||||
|
|
|
@ -228,3 +228,38 @@ func TestChallengeFailed(t *testing.T) {
|
|||
assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeviceCode_Account(t *testing.T) {
|
||||
qa.HTTPFixtures{
|
||||
{
|
||||
Method: "POST",
|
||||
Resource: "/oidc/accounts/xyz/v1/device_authorization",
|
||||
Response: `{"device_code":"abc","user_code":"def","verification_uri":"ghi"}`,
|
||||
},
|
||||
{
|
||||
Method: "POST",
|
||||
Resource: "/oidc/accounts/xyz/v1/token",
|
||||
Response: `access_token=jkl&refresh_token=mnop`,
|
||||
},
|
||||
}.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) {
|
||||
ctx = useInsecureOAuthHttpClientForTests(ctx)
|
||||
tokenStored := false
|
||||
p := &PersistentAuth{
|
||||
Host: c.Config.Host,
|
||||
AccountID: "xyz",
|
||||
cache: &tokenCacheMock{
|
||||
store: func(key string, tok *oauth2.Token) error {
|
||||
assert.Equal(t, fmt.Sprintf("%s/oidc/accounts/xyz", c.Config.Host), key)
|
||||
assert.Equal(t, "mnop", tok.RefreshToken)
|
||||
tokenStored = true
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
defer p.Close()
|
||||
|
||||
err := p.DeviceCode(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, tokenStored)
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue