This commit is contained in:
Miles Yucht 2024-10-17 14:19:52 +00:00 committed by GitHub
commit 742a1058de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 107 additions and 14 deletions

View File

@ -84,10 +84,13 @@ depends on the existing profiles you have set in your configuration file
var loginTimeout time.Duration
var configureCluster bool
var deviceCode bool
cmd.Flags().DurationVar(&loginTimeout, "timeout", defaultTimeout,
"Timeout for completing login challenge in the browser")
cmd.Flags().BoolVar(&configureCluster, "configure-cluster", false,
"Prompts to configure cluster")
cmd.Flags().BoolVar(&deviceCode, "device-code", false,
"Use device code flow for authentication")
cmd.RunE = func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
@ -120,7 +123,11 @@ depends on the existing profiles you have set in your configuration file
ctx, cancel := context.WithTimeout(ctx, loginTimeout)
defer cancel()
err = persistentAuth.Challenge(ctx)
if deviceCode {
err = persistentAuth.DeviceCode(ctx)
} else {
err = persistentAuth.Challenge(ctx)
}
if err != nil {
return err
}

View File

@ -46,18 +46,26 @@ const (
)
var ( // Databricks SDK API: `databricks OAuth is not` will be checked for presence
ErrOAuthNotSupported = errors.New("databricks OAuth is not supported for this host")
ErrNotConfigured = errors.New("databricks OAuth is not configured for this host")
ErrFetchCredentials = errors.New("cannot fetch credentials")
ErrOAuthNotSupported = errors.New("databricks OAuth is not supported for this host")
ErrNotConfigured = errors.New("databricks OAuth is not configured for this host")
ErrFetchCredentials = errors.New("cannot fetch credentials")
ErrDeviceCodeNotSupported = errors.New("device code flow is not supported for this host")
)
type PersistentAuth struct {
Host string
AccountID string
http *httpclient.ApiClient
cache cache.TokenCache
ln net.Listener
// The client used when making requests to Databricks OAuth endpoints.
http *httpclient.ApiClient
// A token cache for OAuth access & refresh tokens.
cache cache.TokenCache
// A listener used to receive the OAuth callback. Not used for device-code flow.
ln net.Listener
// A function to open a URL in the user's browser. Not used for device-code flow.
browser func(string) error
}
@ -113,11 +121,44 @@ func (a *PersistentAuth) ProfileName() string {
return split[0]
}
func (a *PersistentAuth) DeviceCode(ctx context.Context) error {
err := a.init(ctx)
if err != nil {
return fmt.Errorf("init: %w", err)
}
cfg, err := a.oauth2Config(ctx)
if err != nil {
return err
}
if cfg.Endpoint.DeviceAuthURL == "" {
return ErrDeviceCodeNotSupported
}
ctx = a.http.InContextForOAuth2(ctx)
deviceAuthResp, err := cfg.DeviceAuth(ctx)
if err != nil {
return fmt.Errorf("error initiating device code flow: %w", err)
}
fmt.Printf("To authenticate, please visit %s and enter the code %s\n", deviceAuthResp.VerificationURI, deviceAuthResp.UserCode)
token, err := cfg.DeviceAccessToken(ctx, deviceAuthResp)
if err != nil {
return fmt.Errorf("error retrieving token: %w", err)
}
err = a.cache.Store(a.key(), token)
if err != nil {
return fmt.Errorf("store: %w", err)
}
return nil
}
func (a *PersistentAuth) Challenge(ctx context.Context) error {
err := a.init(ctx)
if err != nil {
return fmt.Errorf("init: %w", err)
}
err = a.initU2M(ctx)
if err != nil {
return fmt.Errorf("init: %w", err)
}
cfg, err := a.oauth2Config(ctx)
if err != nil {
return err
@ -143,6 +184,8 @@ func (a *PersistentAuth) Challenge(ctx context.Context) error {
return nil
}
// init validates that the host and account id are set and initializes the http client and token cache.
// It should be called before any other method on PersistentAuth.
func (a *PersistentAuth) init(ctx context.Context) error {
if a.Host == "" && a.AccountID == "" {
return ErrFetchCredentials
@ -153,6 +196,11 @@ func (a *PersistentAuth) init(ctx context.Context) error {
if a.cache == nil {
a.cache = cache.GetTokenCache(ctx)
}
return nil
}
// initU2M initializes the listener for the user-to-machine flow. It does not need to be called for device-code flow.
func (a *PersistentAuth) initU2M(ctx context.Context) error {
if a.browser == nil {
a.browser = browser.OpenURL
}
@ -186,8 +234,9 @@ func (a *PersistentAuth) oidcEndpoints(ctx context.Context) (*oauthAuthorization
prefix := a.key()
if a.AccountID != "" {
return &oauthAuthorizationServer{
AuthorizationEndpoint: fmt.Sprintf("%s/v1/authorize", prefix),
TokenEndpoint: fmt.Sprintf("%s/v1/token", prefix),
AuthorizationEndpoint: fmt.Sprintf("%s/v1/authorize", prefix),
TokenEndpoint: fmt.Sprintf("%s/v1/token", prefix),
DeviceAuthorizationEndpoint: fmt.Sprintf("%s/v1/device_authorization", prefix),
}, nil
}
var oauthEndpoints oauthAuthorizationServer
@ -219,9 +268,10 @@ func (a *PersistentAuth) oauth2Config(ctx context.Context) (*oauth2.Config, erro
return &oauth2.Config{
ClientID: appClientID,
Endpoint: oauth2.Endpoint{
AuthURL: endpoints.AuthorizationEndpoint,
TokenURL: endpoints.TokenEndpoint,
AuthStyle: oauth2.AuthStyleInParams,
AuthURL: endpoints.AuthorizationEndpoint,
TokenURL: endpoints.TokenEndpoint,
DeviceAuthURL: endpoints.DeviceAuthorizationEndpoint,
AuthStyle: oauth2.AuthStyleInParams,
},
RedirectURL: fmt.Sprintf("http://%s", appRedirectAddr),
Scopes: scopes,
@ -260,6 +310,7 @@ func (a *PersistentAuth) randomString(size int) string {
}
type oauthAuthorizationServer struct {
AuthorizationEndpoint string `json:"authorization_endpoint"` // ../v1/authorize
TokenEndpoint string `json:"token_endpoint"` // ../v1/token
AuthorizationEndpoint string `json:"authorization_endpoint"` // ../v1/authorize
TokenEndpoint string `json:"token_endpoint"` // ../v1/token
DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint"` // ../v1/device_authorization
}

View File

@ -228,3 +228,38 @@ func TestChallengeFailed(t *testing.T) {
assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request")
})
}
func TestDeviceCode_Account(t *testing.T) {
qa.HTTPFixtures{
{
Method: "POST",
Resource: "/oidc/accounts/xyz/v1/device_authorization",
Response: `{"device_code":"abc","user_code":"def","verification_uri":"ghi"}`,
},
{
Method: "POST",
Resource: "/oidc/accounts/xyz/v1/token",
Response: `access_token=jkl&refresh_token=mnop`,
},
}.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) {
ctx = useInsecureOAuthHttpClientForTests(ctx)
tokenStored := false
p := &PersistentAuth{
Host: c.Config.Host,
AccountID: "xyz",
cache: &tokenCacheMock{
store: func(key string, tok *oauth2.Token) error {
assert.Equal(t, fmt.Sprintf("%s/oidc/accounts/xyz", c.Config.Host), key)
assert.Equal(t, "mnop", tok.RefreshToken)
tokenStored = true
return nil
},
},
}
defer p.Close()
err := p.DeviceCode(ctx)
assert.NoError(t, err)
assert.True(t, tokenStored)
})
}