mirror of https://github.com/databricks/cli.git
278 lines
7.6 KiB
Go
278 lines
7.6 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
_ "embed"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/databricks/cli/libs/auth/cache"
|
|
"github.com/databricks/databricks-sdk-go/httpclient"
|
|
"github.com/databricks/databricks-sdk-go/retries"
|
|
"github.com/pkg/browser"
|
|
"golang.org/x/oauth2"
|
|
"golang.org/x/oauth2/authhandler"
|
|
)
|
|
|
|
var apiClientForOauth int
|
|
|
|
func WithApiClientForOAuth(ctx context.Context, c *httpclient.ApiClient) context.Context {
|
|
return context.WithValue(ctx, &apiClientForOauth, c)
|
|
}
|
|
|
|
func GetApiClientForOAuth(ctx context.Context) *httpclient.ApiClient {
|
|
c, ok := ctx.Value(&apiClientForOauth).(*httpclient.ApiClient)
|
|
if !ok {
|
|
return httpclient.NewApiClient(httpclient.ClientConfig{})
|
|
}
|
|
return c
|
|
}
|
|
|
|
// appClientID is the public OAuth client ID that Databricks predefined for
// this CLI. It belongs to this application only; other applications must
// not use it.
const appClientID = "databricks-cli"

// appRedirectAddr is the local address the OAuth redirect is sent to. It is
// also the address bound to serialize CLI processes on this machine.
const appRedirectAddr = "localhost:8020"

// listenerTimeout is the longest we keep retrying to bind appRedirectAddr.
const listenerTimeout = 45 * time.Second
|
|
|
|
// The Databricks SDK detects the errors below by checking whether the message
// contains "databricks OAuth is not", so that wording must be kept intact.

// ErrOAuthNotSupported reports that the host does not offer Databricks OAuth.
var ErrOAuthNotSupported = errors.New("databricks OAuth is not supported for this host")

// ErrNotConfigured reports that Databricks OAuth is not configured for the host.
var ErrNotConfigured = errors.New("databricks OAuth is not configured for this host")

// ErrFetchCredentials is returned when neither a host nor an account ID was
// provided, so there is nothing to fetch credentials for.
var ErrFetchCredentials = errors.New("cannot fetch credentials")
|
|
|
|
// PersistentAuth obtains Databricks OAuth tokens for a workspace host or an
// account, and keeps them in a local token cache.
//
// Unset unexported dependencies (http, cache, browser) are filled with
// defaults by init. Methods that call init (Load, Challenge) bind
// appRedirectAddr, and the listener stays held until Close is called.
type PersistentAuth struct {
	// Host is the workspace (or accounts) host. key() normalizes it in place:
	// a trailing slash is dropped and https:// is prepended when no scheme
	// is present.
	Host string
	// AccountID, when non-empty, scopes tokens to this account. The
	// account-level OIDC endpoints under Host are used in that case.
	AccountID string

	// http performs OIDC discovery and token requests. It defaults to
	// GetApiClientForOAuth(ctx).
	http *httpclient.ApiClient
	// cache stores tokens under key(). It defaults to cache.GetTokenCache(ctx).
	cache cache.TokenCache
	// ln is the listener on appRedirectAddr acquired by init. It also acts as
	// a machine-local exclusive lock and is released by Close.
	ln net.Listener
	// browser opens a URL for the user. It defaults to browser.OpenURL.
	browser func(string) error
}
|
|
|
|
func (a *PersistentAuth) SetApiClient(h *httpclient.ApiClient) {
|
|
a.http = h
|
|
}
|
|
|
|
func (a *PersistentAuth) Load(ctx context.Context) (*oauth2.Token, error) {
|
|
err := a.init(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("init: %w", err)
|
|
}
|
|
// lookup token identified by host (and possibly the account id)
|
|
key := a.key()
|
|
t, err := a.cache.Lookup(key)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("cache: %w", err)
|
|
}
|
|
// early return for valid tokens
|
|
if t.Valid() {
|
|
// do not print refresh token to end-user
|
|
t.RefreshToken = ""
|
|
return t, nil
|
|
}
|
|
// OAuth2 config is invoked only for expired tokens to speed up
|
|
// the happy path in the token retrieval
|
|
cfg, err := a.oauth2Config(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// make OAuth2 library use our client
|
|
ctx = a.http.InContextForOAuth2(ctx)
|
|
// eagerly refresh token
|
|
refreshed, err := cfg.TokenSource(ctx, t).Token()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("token refresh: %w", err)
|
|
}
|
|
err = a.cache.Store(key, refreshed)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("cache refresh: %w", err)
|
|
}
|
|
// do not print refresh token to end-user
|
|
refreshed.RefreshToken = ""
|
|
return refreshed, nil
|
|
}
|
|
|
|
func (a *PersistentAuth) ProfileName() string {
|
|
if a.AccountID != "" {
|
|
return fmt.Sprintf("ACCOUNT-%s", a.AccountID)
|
|
}
|
|
host := strings.TrimPrefix(a.Host, "https://")
|
|
split := strings.Split(host, ".")
|
|
return split[0]
|
|
}
|
|
|
|
func (a *PersistentAuth) Challenge(ctx context.Context) error {
|
|
err := a.init(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("init: %w", err)
|
|
}
|
|
cfg, err := a.oauth2Config(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
cb, err := newCallback(ctx, a)
|
|
if err != nil {
|
|
return fmt.Errorf("callback server: %w", err)
|
|
}
|
|
defer cb.Close()
|
|
state, pkce := a.stateAndPKCE()
|
|
// make OAuth2 library use our client
|
|
ctx = a.http.InContextForOAuth2(ctx)
|
|
ts := authhandler.TokenSourceWithPKCE(ctx, cfg, state, cb.Handler, pkce)
|
|
t, err := ts.Token()
|
|
if err != nil {
|
|
return fmt.Errorf("authorize: %w", err)
|
|
}
|
|
// cache token identified by host (and possibly the account id)
|
|
err = a.cache.Store(a.key(), t)
|
|
if err != nil {
|
|
return fmt.Errorf("store: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (a *PersistentAuth) ClearToken(ctx context.Context) error {
|
|
if a.Host == "" && a.AccountID == "" {
|
|
return ErrFetchCredentials
|
|
}
|
|
if a.cache == nil {
|
|
a.cache = cache.GetTokenCache(ctx)
|
|
}
|
|
// lookup token identified by host (and possibly the account id)
|
|
key := a.key()
|
|
return a.cache.DeleteKey(key)
|
|
}
|
|
|
|
func (a *PersistentAuth) init(ctx context.Context) error {
|
|
if a.Host == "" && a.AccountID == "" {
|
|
return ErrFetchCredentials
|
|
}
|
|
if a.http == nil {
|
|
a.http = GetApiClientForOAuth(ctx)
|
|
}
|
|
if a.cache == nil {
|
|
a.cache = cache.GetTokenCache(ctx)
|
|
}
|
|
if a.browser == nil {
|
|
a.browser = browser.OpenURL
|
|
}
|
|
// try acquire listener, which we also use as a machine-local
|
|
// exclusive lock to prevent token cache corruption in the scope
|
|
// of developer machine, where this command runs.
|
|
listener, err := retries.Poll(ctx, listenerTimeout,
|
|
func() (*net.Listener, *retries.Err) {
|
|
var lc net.ListenConfig
|
|
l, err := lc.Listen(ctx, "tcp", appRedirectAddr)
|
|
if err != nil {
|
|
return nil, retries.Continue(err)
|
|
}
|
|
return &l, nil
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("listener: %w", err)
|
|
}
|
|
a.ln = *listener
|
|
return nil
|
|
}
|
|
|
|
func (a *PersistentAuth) Close() error {
|
|
if a.ln == nil {
|
|
return nil
|
|
}
|
|
return a.ln.Close()
|
|
}
|
|
|
|
func (a *PersistentAuth) oidcEndpoints(ctx context.Context) (*oauthAuthorizationServer, error) {
|
|
prefix := a.key()
|
|
if a.AccountID != "" {
|
|
return &oauthAuthorizationServer{
|
|
AuthorizationEndpoint: fmt.Sprintf("%s/v1/authorize", prefix),
|
|
TokenEndpoint: fmt.Sprintf("%s/v1/token", prefix),
|
|
}, nil
|
|
}
|
|
var oauthEndpoints oauthAuthorizationServer
|
|
oidc := fmt.Sprintf("%s/oidc/.well-known/oauth-authorization-server", prefix)
|
|
err := a.http.Do(ctx, "GET", oidc, httpclient.WithResponseUnmarshal(&oauthEndpoints))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("fetch .well-known: %w", err)
|
|
}
|
|
var httpErr *httpclient.HttpError
|
|
if errors.As(err, &httpErr) && httpErr.StatusCode != 200 {
|
|
return nil, ErrOAuthNotSupported
|
|
}
|
|
return &oauthEndpoints, nil
|
|
}
|
|
|
|
func (a *PersistentAuth) oauth2Config(ctx context.Context) (*oauth2.Config, error) {
|
|
// in this iteration of CLI, we're using all scopes by default,
|
|
// because tools like CLI and Terraform do use all apis. This
|
|
// decision may be reconsidered later, once we have a proper
|
|
// taxonomy of all scopes ready and implemented.
|
|
scopes := []string{
|
|
"offline_access",
|
|
"all-apis",
|
|
}
|
|
endpoints, err := a.oidcEndpoints(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("oidc: %w", err)
|
|
}
|
|
return &oauth2.Config{
|
|
ClientID: appClientID,
|
|
Endpoint: oauth2.Endpoint{
|
|
AuthURL: endpoints.AuthorizationEndpoint,
|
|
TokenURL: endpoints.TokenEndpoint,
|
|
AuthStyle: oauth2.AuthStyleInParams,
|
|
},
|
|
RedirectURL: fmt.Sprintf("http://%s", appRedirectAddr),
|
|
Scopes: scopes,
|
|
}, nil
|
|
}
|
|
|
|
// key is currently used for two purposes: OIDC URL prefix and token cache key.
|
|
// once we decide to start storing scopes in the token cache, we should change
|
|
// this approach.
|
|
func (a *PersistentAuth) key() string {
|
|
a.Host = strings.TrimSuffix(a.Host, "/")
|
|
if !strings.HasPrefix(a.Host, "http") {
|
|
a.Host = fmt.Sprintf("https://%s", a.Host)
|
|
}
|
|
if a.AccountID != "" {
|
|
return fmt.Sprintf("%s/oidc/accounts/%s", a.Host, a.AccountID)
|
|
}
|
|
return a.Host
|
|
}
|
|
|
|
func (a *PersistentAuth) stateAndPKCE() (string, *authhandler.PKCEParams) {
|
|
verifier := a.randomString(64)
|
|
verifierSha256 := sha256.Sum256([]byte(verifier))
|
|
challenge := base64.RawURLEncoding.EncodeToString(verifierSha256[:])
|
|
return a.randomString(16), &authhandler.PKCEParams{
|
|
Challenge: challenge,
|
|
ChallengeMethod: "S256",
|
|
Verifier: verifier,
|
|
}
|
|
}
|
|
|
|
func (a *PersistentAuth) randomString(size int) string {
|
|
raw := make([]byte, size)
|
|
_, _ = rand.Read(raw)
|
|
return base64.RawURLEncoding.EncodeToString(raw)
|
|
}
|
|
|
|
// oauthAuthorizationServer holds the fields this package reads from the host's
// OAuth authorization-server metadata document, fetched from
// <host>/oidc/.well-known/oauth-authorization-server. For account logins it is
// instead filled in directly from the account OIDC prefix.
type oauthAuthorizationServer struct {
	AuthorizationEndpoint string `json:"authorization_endpoint"` // ../v1/authorize
	TokenEndpoint         string `json:"token_endpoint"`         // ../v1/token
}
|