From 945d522dab796265cc9bb71b3692316f21ec02b5 Mon Sep 17 00:00:00 2001 From: Serge Smertin <259697+nfx@users.noreply.github.com> Date: Mon, 11 Mar 2024 23:24:23 +0100 Subject: [PATCH] Propagate correct `User-Agent` for CLI (#1264) ## Changes This PR migrates `databricks auth login` HTTP client to the one from Go SDK, making API calls more robust and containing our unified user agent. ## Tests Unit tests left almost unchanged --- libs/auth/oauth.go | 47 ++++++++++++++++------------------------- libs/auth/oauth_test.go | 33 ++++++++++++----------------- 2 files changed, 32 insertions(+), 48 deletions(-) diff --git a/libs/auth/oauth.go b/libs/auth/oauth.go index dd27d04b..4ce0d4de 100644 --- a/libs/auth/oauth.go +++ b/libs/auth/oauth.go @@ -6,16 +6,14 @@ import ( "crypto/sha256" _ "embed" "encoding/base64" - "encoding/json" "errors" "fmt" - "io" "net" - "net/http" "strings" "time" "github.com/databricks/cli/libs/auth/cache" + "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/retries" "github.com/pkg/browser" "golang.org/x/oauth2" @@ -43,16 +41,12 @@ type PersistentAuth struct { Host string AccountID string - http httpGet + http *httpclient.ApiClient cache tokenCache ln net.Listener browser func(string) error } -type httpGet interface { - Get(string) (*http.Response, error) -} - type tokenCache interface { Store(key string, t *oauth2.Token) error Lookup(key string) (*oauth2.Token, error) @@ -77,10 +71,12 @@ func (a *PersistentAuth) Load(ctx context.Context) (*oauth2.Token, error) { } // OAuth2 config is invoked only for expired tokens to speed up // the happy path in the token retrieval - cfg, err := a.oauth2Config() + cfg, err := a.oauth2Config(ctx) if err != nil { return nil, err } + // make OAuth2 library use our client + ctx = a.http.InContextForOAuth2(ctx) // eagerly refresh token refreshed, err := cfg.TokenSource(ctx, t).Token() if err != nil { @@ -110,7 +106,7 @@ func (a *PersistentAuth) Challenge(ctx context.Context) error { if err != nil { return fmt.Errorf("init: %w", err) } - cfg, err := a.oauth2Config() + cfg, err := a.oauth2Config(ctx) if err != nil { return err } @@ -120,6 +116,8 @@ func (a *PersistentAuth) Challenge(ctx context.Context) error { } defer cb.Close() state, pkce := a.stateAndPKCE() + // make OAuth2 library use our client + ctx = a.http.InContextForOAuth2(ctx) ts := authhandler.TokenSourceWithPKCE(ctx, cfg, state, cb.Handler, pkce) t, err := ts.Token() if err != nil { @@ -138,7 +136,9 @@ func (a *PersistentAuth) init(ctx context.Context) error { return ErrFetchCredentials } if a.http == nil { - a.http = http.DefaultClient + a.http = httpclient.NewApiClient(httpclient.ClientConfig{ + // noop + }) } if a.cache == nil { a.cache = &cache.TokenCache{} @@ -172,7 +172,7 @@ func (a *PersistentAuth) Close() error { return a.ln.Close() } -func (a *PersistentAuth) oidcEndpoints() (*oauthAuthorizationServer, error) { +func (a *PersistentAuth) oidcEndpoints(ctx context.Context) (*oauthAuthorizationServer, error) { prefix := a.key() if a.AccountID != "" { return &oauthAuthorizationServer{ @@ -180,31 +180,20 @@ func (a *PersistentAuth) oidcEndpoints() (*oauthAuthorizationServer, error) { TokenEndpoint: fmt.Sprintf("%s/v1/token", prefix), }, nil } + var oauthEndpoints oauthAuthorizationServer oidc := fmt.Sprintf("%s/oidc/.well-known/oauth-authorization-server", prefix) - oidcResponse, err := a.http.Get(oidc) + err := a.http.Do(ctx, "GET", oidc, httpclient.WithResponseUnmarshal(&oauthEndpoints)) if err != nil { return nil, fmt.Errorf("fetch .well-known: %w", err) } - if oidcResponse.StatusCode != 200 { + var httpErr *httpclient.HttpError + if errors.As(err, &httpErr) && httpErr.StatusCode != 200 { return nil, ErrOAuthNotSupported } - if oidcResponse.Body == nil { - return nil, fmt.Errorf("fetch .well-known: empty body") - } - defer oidcResponse.Body.Close() - raw, err := io.ReadAll(oidcResponse.Body) - if err != nil { - return nil, fmt.Errorf("read .well-known: %w", err) - } - var oauthEndpoints oauthAuthorizationServer - err = json.Unmarshal(raw, &oauthEndpoints) - if err != nil { - return nil, fmt.Errorf("parse .well-known: %w", err) - } return &oauthEndpoints, nil } -func (a *PersistentAuth) oauth2Config() (*oauth2.Config, error) { +func (a *PersistentAuth) oauth2Config(ctx context.Context) (*oauth2.Config, error) { // in this iteration of CLI, we're using all scopes by default, // because tools like CLI and Terraform do use all apis. This // decision may be reconsidered later, once we have a proper @@ -213,7 +202,7 @@ func (a *PersistentAuth) oauth2Config() (*oauth2.Config, error) { "offline_access", "all-apis", } - endpoints, err := a.oidcEndpoints() + endpoints, err := a.oidcEndpoints(ctx) if err != nil { return nil, fmt.Errorf("oidc: %w", err) } diff --git a/libs/auth/oauth_test.go b/libs/auth/oauth_test.go index 9b5aa9ac..ea6a8061 100644 --- a/libs/auth/oauth_test.go +++ b/libs/auth/oauth_test.go @@ -5,14 +5,14 @@ import ( "crypto/tls" _ "embed" "fmt" - "io" "net/http" "net/url" - "strings" "testing" "time" "github.com/databricks/databricks-sdk-go/client" + "github.com/databricks/databricks-sdk-go/httpclient" + "github.com/databricks/databricks-sdk-go/httpclient/fixtures" "github.com/databricks/databricks-sdk-go/qa" "github.com/stretchr/testify/assert" "golang.org/x/oauth2" @@ -24,34 +24,29 @@ func TestOidcEndpointsForAccounts(t *testing.T) { AccountID: "xyz", } defer p.Close() - s, err := p.oidcEndpoints() + s, err := p.oidcEndpoints(context.Background()) assert.NoError(t, err) assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/authorize", s.AuthorizationEndpoint) assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/token", s.TokenEndpoint) } -type mockGet func(url string) (*http.Response, error) - -func (m mockGet) Get(url string) (*http.Response, error) { - return m(url) -} - func TestOidcForWorkspace(t *testing.T) { p := &PersistentAuth{ Host: "abc", - http: mockGet(func(url string) (*http.Response, error) { - assert.Equal(t, "https://abc/oidc/.well-known/oauth-authorization-server", url) - return &http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader(`{ - "authorization_endpoint": "a", - "token_endpoint": "b" - }`)), - }, nil + http: httpclient.NewApiClient(httpclient.ClientConfig{ + Transport: fixtures.MappingTransport{ + "GET /oidc/.well-known/oauth-authorization-server": { + Status: 200, + Response: map[string]string{ + "authorization_endpoint": "a", + "token_endpoint": "b", + }, + }, + }, }), } defer p.Close() - endpoints, err := p.oidcEndpoints() + endpoints, err := p.oidcEndpoints(context.Background()) assert.NoError(t, err) assert.Equal(t, "a", endpoints.AuthorizationEndpoint) assert.Equal(t, "b", endpoints.TokenEndpoint)