mirror of https://github.com/databricks/cli.git
hackathon device code flow
This commit is contained in:
parent
fb077a85d2
commit
f9675ab8ea
|
@ -84,10 +84,13 @@ depends on the existing profiles you have set in your configuration file
|
||||||
|
|
||||||
var loginTimeout time.Duration
|
var loginTimeout time.Duration
|
||||||
var configureCluster bool
|
var configureCluster bool
|
||||||
|
var deviceCode bool
|
||||||
cmd.Flags().DurationVar(&loginTimeout, "timeout", defaultTimeout,
|
cmd.Flags().DurationVar(&loginTimeout, "timeout", defaultTimeout,
|
||||||
"Timeout for completing login challenge in the browser")
|
"Timeout for completing login challenge in the browser")
|
||||||
cmd.Flags().BoolVar(&configureCluster, "configure-cluster", false,
|
cmd.Flags().BoolVar(&configureCluster, "configure-cluster", false,
|
||||||
"Prompts to configure cluster")
|
"Prompts to configure cluster")
|
||||||
|
cmd.Flags().BoolVar(&deviceCode, "device-code", false,
|
||||||
|
"Use device code flow for authentication")
|
||||||
|
|
||||||
cmd.RunE = func(cmd *cobra.Command, args []string) error {
|
cmd.RunE = func(cmd *cobra.Command, args []string) error {
|
||||||
ctx := cmd.Context()
|
ctx := cmd.Context()
|
||||||
|
@ -120,7 +123,11 @@ depends on the existing profiles you have set in your configuration file
|
||||||
ctx, cancel := context.WithTimeout(ctx, loginTimeout)
|
ctx, cancel := context.WithTimeout(ctx, loginTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
err = persistentAuth.Challenge(ctx)
|
if deviceCode {
|
||||||
|
err = persistentAuth.DeviceCode(ctx)
|
||||||
|
} else {
|
||||||
|
err = persistentAuth.Challenge(ctx)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -55,9 +55,16 @@ type PersistentAuth struct {
|
||||||
Host string
|
Host string
|
||||||
AccountID string
|
AccountID string
|
||||||
|
|
||||||
http *httpclient.ApiClient
|
// The client used when making requests to Databricks OAuth endpoints.
|
||||||
cache cache.TokenCache
|
http *httpclient.ApiClient
|
||||||
ln net.Listener
|
|
||||||
|
// A token cache for OAuth access & refresh tokens.
|
||||||
|
cache cache.TokenCache
|
||||||
|
|
||||||
|
// A listener used to receive the OAuth callback. Not used for device-code flow.
|
||||||
|
ln net.Listener
|
||||||
|
|
||||||
|
// A function to open a URL in the user's browser. Not used for device-code flow.
|
||||||
browser func(string) error
|
browser func(string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -113,11 +120,41 @@ func (a *PersistentAuth) ProfileName() string {
|
||||||
return split[0]
|
return split[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *PersistentAuth) DeviceCode(ctx context.Context) error {
|
||||||
|
err := a.init(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("init: %w", err)
|
||||||
|
}
|
||||||
|
cfg, err := a.oauth2Config(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ctx = a.http.InContextForOAuth2(ctx)
|
||||||
|
deviceAuthResp, err := cfg.DeviceAuth(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error initiating device code flow: %w", err)
|
||||||
|
}
|
||||||
|
fmt.Printf("To authenticate, please visit %s and enter the code %s\n", deviceAuthResp.VerificationURI, deviceAuthResp.UserCode)
|
||||||
|
token, err := cfg.DeviceAccessToken(ctx, deviceAuthResp)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error retrieving token: %w", err)
|
||||||
|
}
|
||||||
|
err = a.cache.Store(a.key(), token)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("store: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a *PersistentAuth) Challenge(ctx context.Context) error {
|
func (a *PersistentAuth) Challenge(ctx context.Context) error {
|
||||||
err := a.init(ctx)
|
err := a.init(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("init: %w", err)
|
return fmt.Errorf("init: %w", err)
|
||||||
}
|
}
|
||||||
|
err = a.initU2M(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("init: %w", err)
|
||||||
|
}
|
||||||
cfg, err := a.oauth2Config(ctx)
|
cfg, err := a.oauth2Config(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -143,6 +180,8 @@ func (a *PersistentAuth) Challenge(ctx context.Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// init validates that the host and account id are set and initializes the http client and token cache.
|
||||||
|
// It should be called before any other method on PersistentAuth.
|
||||||
func (a *PersistentAuth) init(ctx context.Context) error {
|
func (a *PersistentAuth) init(ctx context.Context) error {
|
||||||
if a.Host == "" && a.AccountID == "" {
|
if a.Host == "" && a.AccountID == "" {
|
||||||
return ErrFetchCredentials
|
return ErrFetchCredentials
|
||||||
|
@ -153,6 +192,11 @@ func (a *PersistentAuth) init(ctx context.Context) error {
|
||||||
if a.cache == nil {
|
if a.cache == nil {
|
||||||
a.cache = cache.GetTokenCache(ctx)
|
a.cache = cache.GetTokenCache(ctx)
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// initU2M initializes the listener for the user-to-machine flow. It does not need to be called for device-code flow.
|
||||||
|
func (a *PersistentAuth) initU2M(ctx context.Context) error {
|
||||||
if a.browser == nil {
|
if a.browser == nil {
|
||||||
a.browser = browser.OpenURL
|
a.browser = browser.OpenURL
|
||||||
}
|
}
|
||||||
|
@ -186,8 +230,9 @@ func (a *PersistentAuth) oidcEndpoints(ctx context.Context) (*oauthAuthorization
|
||||||
prefix := a.key()
|
prefix := a.key()
|
||||||
if a.AccountID != "" {
|
if a.AccountID != "" {
|
||||||
return &oauthAuthorizationServer{
|
return &oauthAuthorizationServer{
|
||||||
AuthorizationEndpoint: fmt.Sprintf("%s/v1/authorize", prefix),
|
AuthorizationEndpoint: fmt.Sprintf("%s/v1/authorize", prefix),
|
||||||
TokenEndpoint: fmt.Sprintf("%s/v1/token", prefix),
|
TokenEndpoint: fmt.Sprintf("%s/v1/token", prefix),
|
||||||
|
DeviceAuthorizationEndpoint: fmt.Sprintf("%s/v1/device_authorization", prefix),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
var oauthEndpoints oauthAuthorizationServer
|
var oauthEndpoints oauthAuthorizationServer
|
||||||
|
@ -219,9 +264,10 @@ func (a *PersistentAuth) oauth2Config(ctx context.Context) (*oauth2.Config, erro
|
||||||
return &oauth2.Config{
|
return &oauth2.Config{
|
||||||
ClientID: appClientID,
|
ClientID: appClientID,
|
||||||
Endpoint: oauth2.Endpoint{
|
Endpoint: oauth2.Endpoint{
|
||||||
AuthURL: endpoints.AuthorizationEndpoint,
|
AuthURL: endpoints.AuthorizationEndpoint,
|
||||||
TokenURL: endpoints.TokenEndpoint,
|
TokenURL: endpoints.TokenEndpoint,
|
||||||
AuthStyle: oauth2.AuthStyleInParams,
|
DeviceAuthURL: endpoints.DeviceAuthorizationEndpoint,
|
||||||
|
AuthStyle: oauth2.AuthStyleInParams,
|
||||||
},
|
},
|
||||||
RedirectURL: fmt.Sprintf("http://%s", appRedirectAddr),
|
RedirectURL: fmt.Sprintf("http://%s", appRedirectAddr),
|
||||||
Scopes: scopes,
|
Scopes: scopes,
|
||||||
|
@ -260,6 +306,7 @@ func (a *PersistentAuth) randomString(size int) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
type oauthAuthorizationServer struct {
|
type oauthAuthorizationServer struct {
|
||||||
AuthorizationEndpoint string `json:"authorization_endpoint"` // ../v1/authorize
|
AuthorizationEndpoint string `json:"authorization_endpoint"` // ../v1/authorize
|
||||||
TokenEndpoint string `json:"token_endpoint"` // ../v1/token
|
TokenEndpoint string `json:"token_endpoint"` // ../v1/token
|
||||||
|
DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint"` // ../v1/device_authorization
|
||||||
}
|
}
|
||||||
|
|
|
@ -228,3 +228,38 @@ func TestChallengeFailed(t *testing.T) {
|
||||||
assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request")
|
assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDeviceCode_Account(t *testing.T) {
|
||||||
|
qa.HTTPFixtures{
|
||||||
|
{
|
||||||
|
Method: "POST",
|
||||||
|
Resource: "/oidc/accounts/xyz/v1/device_authorization",
|
||||||
|
Response: `{"device_code":"abc","user_code":"def","verification_uri":"ghi"}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Method: "POST",
|
||||||
|
Resource: "/oidc/accounts/xyz/v1/token",
|
||||||
|
Response: `access_token=jkl&refresh_token=mnop`,
|
||||||
|
},
|
||||||
|
}.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) {
|
||||||
|
ctx = useInsecureOAuthHttpClientForTests(ctx)
|
||||||
|
tokenStored := false
|
||||||
|
p := &PersistentAuth{
|
||||||
|
Host: c.Config.Host,
|
||||||
|
AccountID: "xyz",
|
||||||
|
cache: &tokenCacheMock{
|
||||||
|
store: func(key string, tok *oauth2.Token) error {
|
||||||
|
assert.Equal(t, fmt.Sprintf("%s/oidc/accounts/xyz", c.Config.Host), key)
|
||||||
|
assert.Equal(t, "mnop", tok.RefreshToken)
|
||||||
|
tokenStored = true
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
defer p.Close()
|
||||||
|
|
||||||
|
err := p.DeviceCode(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, tokenStored)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue