mirror of https://github.com/databricks/cli.git
Propagate correct `User-Agent` for CLI (#1264)
## Changes This PR migrates `databricks auth login` HTTP client to the one from Go SDK, making API calls more robust and containing our unified user agent. ## Tests Unit tests left almost unchanged
This commit is contained in:
parent
4a9a12af19
commit
945d522dab
|
@ -6,16 +6,14 @@ import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
_ "embed"
|
_ "embed"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/databricks/cli/libs/auth/cache"
|
"github.com/databricks/cli/libs/auth/cache"
|
||||||
|
"github.com/databricks/databricks-sdk-go/httpclient"
|
||||||
"github.com/databricks/databricks-sdk-go/retries"
|
"github.com/databricks/databricks-sdk-go/retries"
|
||||||
"github.com/pkg/browser"
|
"github.com/pkg/browser"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
@ -43,16 +41,12 @@ type PersistentAuth struct {
|
||||||
Host string
|
Host string
|
||||||
AccountID string
|
AccountID string
|
||||||
|
|
||||||
http httpGet
|
http *httpclient.ApiClient
|
||||||
cache tokenCache
|
cache tokenCache
|
||||||
ln net.Listener
|
ln net.Listener
|
||||||
browser func(string) error
|
browser func(string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type httpGet interface {
|
|
||||||
Get(string) (*http.Response, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type tokenCache interface {
|
type tokenCache interface {
|
||||||
Store(key string, t *oauth2.Token) error
|
Store(key string, t *oauth2.Token) error
|
||||||
Lookup(key string) (*oauth2.Token, error)
|
Lookup(key string) (*oauth2.Token, error)
|
||||||
|
@ -77,10 +71,12 @@ func (a *PersistentAuth) Load(ctx context.Context) (*oauth2.Token, error) {
|
||||||
}
|
}
|
||||||
// OAuth2 config is invoked only for expired tokens to speed up
|
// OAuth2 config is invoked only for expired tokens to speed up
|
||||||
// the happy path in the token retrieval
|
// the happy path in the token retrieval
|
||||||
cfg, err := a.oauth2Config()
|
cfg, err := a.oauth2Config(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
// make OAuth2 library use our client
|
||||||
|
ctx = a.http.InContextForOAuth2(ctx)
|
||||||
// eagerly refresh token
|
// eagerly refresh token
|
||||||
refreshed, err := cfg.TokenSource(ctx, t).Token()
|
refreshed, err := cfg.TokenSource(ctx, t).Token()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -110,7 +106,7 @@ func (a *PersistentAuth) Challenge(ctx context.Context) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("init: %w", err)
|
return fmt.Errorf("init: %w", err)
|
||||||
}
|
}
|
||||||
cfg, err := a.oauth2Config()
|
cfg, err := a.oauth2Config(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -120,6 +116,8 @@ func (a *PersistentAuth) Challenge(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
defer cb.Close()
|
defer cb.Close()
|
||||||
state, pkce := a.stateAndPKCE()
|
state, pkce := a.stateAndPKCE()
|
||||||
|
// make OAuth2 library use our client
|
||||||
|
ctx = a.http.InContextForOAuth2(ctx)
|
||||||
ts := authhandler.TokenSourceWithPKCE(ctx, cfg, state, cb.Handler, pkce)
|
ts := authhandler.TokenSourceWithPKCE(ctx, cfg, state, cb.Handler, pkce)
|
||||||
t, err := ts.Token()
|
t, err := ts.Token()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -138,7 +136,9 @@ func (a *PersistentAuth) init(ctx context.Context) error {
|
||||||
return ErrFetchCredentials
|
return ErrFetchCredentials
|
||||||
}
|
}
|
||||||
if a.http == nil {
|
if a.http == nil {
|
||||||
a.http = http.DefaultClient
|
a.http = httpclient.NewApiClient(httpclient.ClientConfig{
|
||||||
|
// noop
|
||||||
|
})
|
||||||
}
|
}
|
||||||
if a.cache == nil {
|
if a.cache == nil {
|
||||||
a.cache = &cache.TokenCache{}
|
a.cache = &cache.TokenCache{}
|
||||||
|
@ -172,7 +172,7 @@ func (a *PersistentAuth) Close() error {
|
||||||
return a.ln.Close()
|
return a.ln.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *PersistentAuth) oidcEndpoints() (*oauthAuthorizationServer, error) {
|
func (a *PersistentAuth) oidcEndpoints(ctx context.Context) (*oauthAuthorizationServer, error) {
|
||||||
prefix := a.key()
|
prefix := a.key()
|
||||||
if a.AccountID != "" {
|
if a.AccountID != "" {
|
||||||
return &oauthAuthorizationServer{
|
return &oauthAuthorizationServer{
|
||||||
|
@ -180,31 +180,20 @@ func (a *PersistentAuth) oidcEndpoints() (*oauthAuthorizationServer, error) {
|
||||||
TokenEndpoint: fmt.Sprintf("%s/v1/token", prefix),
|
TokenEndpoint: fmt.Sprintf("%s/v1/token", prefix),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
var oauthEndpoints oauthAuthorizationServer
|
||||||
oidc := fmt.Sprintf("%s/oidc/.well-known/oauth-authorization-server", prefix)
|
oidc := fmt.Sprintf("%s/oidc/.well-known/oauth-authorization-server", prefix)
|
||||||
oidcResponse, err := a.http.Get(oidc)
|
err := a.http.Do(ctx, "GET", oidc, httpclient.WithResponseUnmarshal(&oauthEndpoints))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("fetch .well-known: %w", err)
|
return nil, fmt.Errorf("fetch .well-known: %w", err)
|
||||||
}
|
}
|
||||||
if oidcResponse.StatusCode != 200 {
|
var httpErr *httpclient.HttpError
|
||||||
|
if errors.As(err, &httpErr) && httpErr.StatusCode != 200 {
|
||||||
return nil, ErrOAuthNotSupported
|
return nil, ErrOAuthNotSupported
|
||||||
}
|
}
|
||||||
if oidcResponse.Body == nil {
|
|
||||||
return nil, fmt.Errorf("fetch .well-known: empty body")
|
|
||||||
}
|
|
||||||
defer oidcResponse.Body.Close()
|
|
||||||
raw, err := io.ReadAll(oidcResponse.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("read .well-known: %w", err)
|
|
||||||
}
|
|
||||||
var oauthEndpoints oauthAuthorizationServer
|
|
||||||
err = json.Unmarshal(raw, &oauthEndpoints)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("parse .well-known: %w", err)
|
|
||||||
}
|
|
||||||
return &oauthEndpoints, nil
|
return &oauthEndpoints, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *PersistentAuth) oauth2Config() (*oauth2.Config, error) {
|
func (a *PersistentAuth) oauth2Config(ctx context.Context) (*oauth2.Config, error) {
|
||||||
// in this iteration of CLI, we're using all scopes by default,
|
// in this iteration of CLI, we're using all scopes by default,
|
||||||
// because tools like CLI and Terraform do use all apis. This
|
// because tools like CLI and Terraform do use all apis. This
|
||||||
// decision may be reconsidered later, once we have a proper
|
// decision may be reconsidered later, once we have a proper
|
||||||
|
@ -213,7 +202,7 @@ func (a *PersistentAuth) oauth2Config() (*oauth2.Config, error) {
|
||||||
"offline_access",
|
"offline_access",
|
||||||
"all-apis",
|
"all-apis",
|
||||||
}
|
}
|
||||||
endpoints, err := a.oidcEndpoints()
|
endpoints, err := a.oidcEndpoints(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("oidc: %w", err)
|
return nil, fmt.Errorf("oidc: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,14 +5,14 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
_ "embed"
|
_ "embed"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/databricks/databricks-sdk-go/client"
|
"github.com/databricks/databricks-sdk-go/client"
|
||||||
|
"github.com/databricks/databricks-sdk-go/httpclient"
|
||||||
|
"github.com/databricks/databricks-sdk-go/httpclient/fixtures"
|
||||||
"github.com/databricks/databricks-sdk-go/qa"
|
"github.com/databricks/databricks-sdk-go/qa"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
@ -24,34 +24,29 @@ func TestOidcEndpointsForAccounts(t *testing.T) {
|
||||||
AccountID: "xyz",
|
AccountID: "xyz",
|
||||||
}
|
}
|
||||||
defer p.Close()
|
defer p.Close()
|
||||||
s, err := p.oidcEndpoints()
|
s, err := p.oidcEndpoints(context.Background())
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/authorize", s.AuthorizationEndpoint)
|
assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/authorize", s.AuthorizationEndpoint)
|
||||||
assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/token", s.TokenEndpoint)
|
assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/token", s.TokenEndpoint)
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockGet func(url string) (*http.Response, error)
|
|
||||||
|
|
||||||
func (m mockGet) Get(url string) (*http.Response, error) {
|
|
||||||
return m(url)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOidcForWorkspace(t *testing.T) {
|
func TestOidcForWorkspace(t *testing.T) {
|
||||||
p := &PersistentAuth{
|
p := &PersistentAuth{
|
||||||
Host: "abc",
|
Host: "abc",
|
||||||
http: mockGet(func(url string) (*http.Response, error) {
|
http: httpclient.NewApiClient(httpclient.ClientConfig{
|
||||||
assert.Equal(t, "https://abc/oidc/.well-known/oauth-authorization-server", url)
|
Transport: fixtures.MappingTransport{
|
||||||
return &http.Response{
|
"GET /oidc/.well-known/oauth-authorization-server": {
|
||||||
StatusCode: 200,
|
Status: 200,
|
||||||
Body: io.NopCloser(strings.NewReader(`{
|
Response: map[string]string{
|
||||||
"authorization_endpoint": "a",
|
"authorization_endpoint": "a",
|
||||||
"token_endpoint": "b"
|
"token_endpoint": "b",
|
||||||
}`)),
|
},
|
||||||
}, nil
|
},
|
||||||
|
},
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
defer p.Close()
|
defer p.Close()
|
||||||
endpoints, err := p.oidcEndpoints()
|
endpoints, err := p.oidcEndpoints(context.Background())
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "a", endpoints.AuthorizationEndpoint)
|
assert.Equal(t, "a", endpoints.AuthorizationEndpoint)
|
||||||
assert.Equal(t, "b", endpoints.TokenEndpoint)
|
assert.Equal(t, "b", endpoints.TokenEndpoint)
|
||||||
|
|
Loading…
Reference in New Issue