hackathon device code flow

This commit is contained in:
Miles Yucht 2024-09-12 10:49:27 +02:00
parent fb077a85d2
commit f9675ab8ea
No known key found for this signature in database
GPG Key ID: CDA4D62DC9997360
3 changed files with 100 additions and 11 deletions

View File

@ -84,10 +84,13 @@ depends on the existing profiles you have set in your configuration file
var loginTimeout time.Duration var loginTimeout time.Duration
var configureCluster bool var configureCluster bool
var deviceCode bool
cmd.Flags().DurationVar(&loginTimeout, "timeout", defaultTimeout, cmd.Flags().DurationVar(&loginTimeout, "timeout", defaultTimeout,
"Timeout for completing login challenge in the browser") "Timeout for completing login challenge in the browser")
cmd.Flags().BoolVar(&configureCluster, "configure-cluster", false, cmd.Flags().BoolVar(&configureCluster, "configure-cluster", false,
"Prompts to configure cluster") "Prompts to configure cluster")
cmd.Flags().BoolVar(&deviceCode, "device-code", false,
"Use device code flow for authentication")
cmd.RunE = func(cmd *cobra.Command, args []string) error { cmd.RunE = func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context() ctx := cmd.Context()
@ -120,7 +123,11 @@ depends on the existing profiles you have set in your configuration file
ctx, cancel := context.WithTimeout(ctx, loginTimeout) ctx, cancel := context.WithTimeout(ctx, loginTimeout)
defer cancel() defer cancel()
err = persistentAuth.Challenge(ctx) if deviceCode {
err = persistentAuth.DeviceCode(ctx)
} else {
err = persistentAuth.Challenge(ctx)
}
if err != nil { if err != nil {
return err return err
} }

View File

@ -55,9 +55,16 @@ type PersistentAuth struct {
Host string Host string
AccountID string AccountID string
http *httpclient.ApiClient // The client used when making requests to Databricks OAuth endpoints.
cache cache.TokenCache http *httpclient.ApiClient
ln net.Listener
// A token cache for OAuth access & refresh tokens.
cache cache.TokenCache
// A listener used to receive the OAuth callback. Not used for device-code flow.
ln net.Listener
// A function to open a URL in the user's browser. Not used for device-code flow.
browser func(string) error browser func(string) error
} }
@ -113,11 +120,41 @@ func (a *PersistentAuth) ProfileName() string {
return split[0] return split[0]
} }
func (a *PersistentAuth) DeviceCode(ctx context.Context) error {
err := a.init(ctx)
if err != nil {
return fmt.Errorf("init: %w", err)
}
cfg, err := a.oauth2Config(ctx)
if err != nil {
return err
}
ctx = a.http.InContextForOAuth2(ctx)
deviceAuthResp, err := cfg.DeviceAuth(ctx)
if err != nil {
return fmt.Errorf("error initiating device code flow: %w", err)
}
fmt.Printf("To authenticate, please visit %s and enter the code %s\n", deviceAuthResp.VerificationURI, deviceAuthResp.UserCode)
token, err := cfg.DeviceAccessToken(ctx, deviceAuthResp)
if err != nil {
return fmt.Errorf("error retrieving token: %w", err)
}
err = a.cache.Store(a.key(), token)
if err != nil {
return fmt.Errorf("store: %w", err)
}
return nil
}
func (a *PersistentAuth) Challenge(ctx context.Context) error { func (a *PersistentAuth) Challenge(ctx context.Context) error {
err := a.init(ctx) err := a.init(ctx)
if err != nil { if err != nil {
return fmt.Errorf("init: %w", err) return fmt.Errorf("init: %w", err)
} }
err = a.initU2M(ctx)
if err != nil {
return fmt.Errorf("init: %w", err)
}
cfg, err := a.oauth2Config(ctx) cfg, err := a.oauth2Config(ctx)
if err != nil { if err != nil {
return err return err
@ -143,6 +180,8 @@ func (a *PersistentAuth) Challenge(ctx context.Context) error {
return nil return nil
} }
// init validates that the host and account id are set and initializes the http client and token cache.
// It should be called before any other method on PersistentAuth.
func (a *PersistentAuth) init(ctx context.Context) error { func (a *PersistentAuth) init(ctx context.Context) error {
if a.Host == "" && a.AccountID == "" { if a.Host == "" && a.AccountID == "" {
return ErrFetchCredentials return ErrFetchCredentials
@ -153,6 +192,11 @@ func (a *PersistentAuth) init(ctx context.Context) error {
if a.cache == nil { if a.cache == nil {
a.cache = cache.GetTokenCache(ctx) a.cache = cache.GetTokenCache(ctx)
} }
return nil
}
// initU2M initializes the listener for the user-to-machine flow. It does not need to be called for device-code flow.
func (a *PersistentAuth) initU2M(ctx context.Context) error {
if a.browser == nil { if a.browser == nil {
a.browser = browser.OpenURL a.browser = browser.OpenURL
} }
@ -186,8 +230,9 @@ func (a *PersistentAuth) oidcEndpoints(ctx context.Context) (*oauthAuthorization
prefix := a.key() prefix := a.key()
if a.AccountID != "" { if a.AccountID != "" {
return &oauthAuthorizationServer{ return &oauthAuthorizationServer{
AuthorizationEndpoint: fmt.Sprintf("%s/v1/authorize", prefix), AuthorizationEndpoint: fmt.Sprintf("%s/v1/authorize", prefix),
TokenEndpoint: fmt.Sprintf("%s/v1/token", prefix), TokenEndpoint: fmt.Sprintf("%s/v1/token", prefix),
DeviceAuthorizationEndpoint: fmt.Sprintf("%s/v1/device_authorization", prefix),
}, nil }, nil
} }
var oauthEndpoints oauthAuthorizationServer var oauthEndpoints oauthAuthorizationServer
@ -219,9 +264,10 @@ func (a *PersistentAuth) oauth2Config(ctx context.Context) (*oauth2.Config, erro
return &oauth2.Config{ return &oauth2.Config{
ClientID: appClientID, ClientID: appClientID,
Endpoint: oauth2.Endpoint{ Endpoint: oauth2.Endpoint{
AuthURL: endpoints.AuthorizationEndpoint, AuthURL: endpoints.AuthorizationEndpoint,
TokenURL: endpoints.TokenEndpoint, TokenURL: endpoints.TokenEndpoint,
AuthStyle: oauth2.AuthStyleInParams, DeviceAuthURL: endpoints.DeviceAuthorizationEndpoint,
AuthStyle: oauth2.AuthStyleInParams,
}, },
RedirectURL: fmt.Sprintf("http://%s", appRedirectAddr), RedirectURL: fmt.Sprintf("http://%s", appRedirectAddr),
Scopes: scopes, Scopes: scopes,
@ -260,6 +306,7 @@ func (a *PersistentAuth) randomString(size int) string {
} }
type oauthAuthorizationServer struct { type oauthAuthorizationServer struct {
AuthorizationEndpoint string `json:"authorization_endpoint"` // ../v1/authorize AuthorizationEndpoint string `json:"authorization_endpoint"` // ../v1/authorize
TokenEndpoint string `json:"token_endpoint"` // ../v1/token TokenEndpoint string `json:"token_endpoint"` // ../v1/token
DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint"` // ../v1/device_authorization
} }

View File

@ -228,3 +228,38 @@ func TestChallengeFailed(t *testing.T) {
assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request") assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request")
}) })
} }
func TestDeviceCode_Account(t *testing.T) {
qa.HTTPFixtures{
{
Method: "POST",
Resource: "/oidc/accounts/xyz/v1/device_authorization",
Response: `{"device_code":"abc","user_code":"def","verification_uri":"ghi"}`,
},
{
Method: "POST",
Resource: "/oidc/accounts/xyz/v1/token",
Response: `access_token=jkl&refresh_token=mnop`,
},
}.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) {
ctx = useInsecureOAuthHttpClientForTests(ctx)
tokenStored := false
p := &PersistentAuth{
Host: c.Config.Host,
AccountID: "xyz",
cache: &tokenCacheMock{
store: func(key string, tok *oauth2.Token) error {
assert.Equal(t, fmt.Sprintf("%s/oidc/accounts/xyz", c.Config.Host), key)
assert.Equal(t, "mnop", tok.RefreshToken)
tokenStored = true
return nil
},
},
}
defer p.Close()
err := p.DeviceCode(ctx)
assert.NoError(t, err)
assert.True(t, tokenStored)
})
}