databricks-cli/libs/auth/oauth.go

291 lines
7.8 KiB
Go
Raw Permalink Normal View History

package auth
import (
"context"
Upgraded Go version to 1.21 (#664) ## Changes Upgraded Go version to 1.21 Upgraded to use `slices` and `slog` from core instead of experimental. Still use `exp/maps` as our code relies on `maps.Keys` which is not part of core package and therefore refactoring required. ### Tests Integration tests passed ``` [DEBUG] Test execution command: /opt/homebrew/opt/go@1.21/bin/go test ./... -json -timeout 1h -run ^TestAcc [DEBUG] Test execution directory: /Users/andrew.nester/cli 2023/08/15 13:20:51 [INFO] ✅ TestAccAlertsCreateErrWhenNoArguments (2.150s) 2023/08/15 13:20:52 [INFO] ✅ TestAccApiGet (0.580s) 2023/08/15 13:20:53 [INFO] ✅ TestAccClustersList (0.900s) 2023/08/15 13:20:54 [INFO] ✅ TestAccClustersGet (0.870s) 2023/08/15 13:21:06 [INFO] ✅ TestAccFilerWorkspaceFilesReadWrite (11.980s) 2023/08/15 13:21:13 [INFO] ✅ TestAccFilerWorkspaceFilesReadDir (7.060s) 2023/08/15 13:21:25 [INFO] ✅ TestAccFilerDbfsReadWrite (12.810s) 2023/08/15 13:21:33 [INFO] ✅ TestAccFilerDbfsReadDir (7.380s) 2023/08/15 13:21:41 [INFO] ✅ TestAccFilerWorkspaceNotebookConflict (7.760s) 2023/08/15 13:21:49 [INFO] ✅ TestAccFilerWorkspaceNotebookWithOverwriteFlag (8.660s) 2023/08/15 13:21:49 [INFO] ✅ TestAccFilerLocalReadWrite (0.020s) 2023/08/15 13:21:49 [INFO] ✅ TestAccFilerLocalReadDir (0.010s) 2023/08/15 13:21:52 [INFO] ✅ TestAccFsCatForDbfs (3.190s) 2023/08/15 13:21:53 [INFO] ✅ TestAccFsCatForDbfsOnNonExistentFile (0.890s) 2023/08/15 13:21:54 [INFO] ✅ TestAccFsCatForDbfsInvalidScheme (0.600s) 2023/08/15 13:21:57 [INFO] ✅ TestAccFsCatDoesNotSupportOutputModeJson (2.960s) 2023/08/15 13:22:28 [INFO] ✅ TestAccFsCpDir (31.480s) 2023/08/15 13:22:43 [INFO] ✅ TestAccFsCpFileToFile (14.530s) 2023/08/15 13:22:58 [INFO] ✅ TestAccFsCpFileToDir (14.610s) 2023/08/15 13:23:29 [INFO] ✅ TestAccFsCpDirToDirFileNotOverwritten (31.810s) 2023/08/15 13:23:47 [INFO] ✅ TestAccFsCpFileToDirFileNotOverwritten (17.500s) 2023/08/15 13:24:04 [INFO] ✅ TestAccFsCpFileToFileFileNotOverwritten (17.260s) 2023/08/15 13:24:37 
[INFO] ✅ TestAccFsCpDirToDirWithOverwriteFlag (32.690s) 2023/08/15 13:24:56 [INFO] ✅ TestAccFsCpFileToFileWithOverwriteFlag (19.290s) 2023/08/15 13:25:15 [INFO] ✅ TestAccFsCpFileToDirWithOverwriteFlag (19.230s) 2023/08/15 13:25:17 [INFO] ✅ TestAccFsCpErrorsWhenSourceIsDirWithoutRecursiveFlag (2.010s) 2023/08/15 13:25:18 [INFO] ✅ TestAccFsCpErrorsOnInvalidScheme (0.610s) 2023/08/15 13:25:33 [INFO] ✅ TestAccFsCpSourceIsDirectoryButTargetIsFile (14.900s) 2023/08/15 13:25:37 [INFO] ✅ TestAccFsLsForDbfs (3.770s) 2023/08/15 13:25:41 [INFO] ✅ TestAccFsLsForDbfsWithAbsolutePaths (4.160s) 2023/08/15 13:25:44 [INFO] ✅ TestAccFsLsForDbfsOnFile (2.990s) 2023/08/15 13:25:46 [INFO] ✅ TestAccFsLsForDbfsOnEmptyDir (1.870s) 2023/08/15 13:25:46 [INFO] ✅ TestAccFsLsForDbfsForNonexistingDir (0.850s) 2023/08/15 13:25:47 [INFO] ✅ TestAccFsLsWithoutScheme (0.560s) 2023/08/15 13:25:49 [INFO] ✅ TestAccFsMkdirCreatesDirectory (2.310s) 2023/08/15 13:25:52 [INFO] ✅ TestAccFsMkdirCreatesMultipleDirectories (2.920s) 2023/08/15 13:25:55 [INFO] ✅ TestAccFsMkdirWhenDirectoryAlreadyExists (2.320s) 2023/08/15 13:25:57 [INFO] ✅ TestAccFsMkdirWhenFileExistsAtPath (2.820s) 2023/08/15 13:26:01 [INFO] ✅ TestAccFsRmForFile (4.030s) 2023/08/15 13:26:05 [INFO] ✅ TestAccFsRmForEmptyDirectory (3.530s) 2023/08/15 13:26:08 [INFO] ✅ TestAccFsRmForNonEmptyDirectory (3.190s) 2023/08/15 13:26:09 [INFO] ✅ TestAccFsRmForNonExistentFile (0.830s) 2023/08/15 13:26:13 [INFO] ✅ TestAccFsRmForNonEmptyDirectoryWithRecursiveFlag (3.580s) 2023/08/15 13:26:13 [INFO] ✅ TestAccGitClone (0.800s) 2023/08/15 13:26:14 [INFO] ✅ TestAccGitCloneWithOnlyRepoNameOnAlternateBranch (0.790s) 2023/08/15 13:26:15 [INFO] ✅ TestAccGitCloneErrorsWhenRepositoryDoesNotExist (0.540s) 2023/08/15 13:26:23 [INFO] ✅ TestAccLock (8.630s) 2023/08/15 13:26:27 [INFO] ✅ TestAccLockUnlockWithoutAllowsLockFileNotExist (3.490s) 2023/08/15 13:26:30 [INFO] ✅ TestAccLockUnlockWithAllowsLockFileNotExist (3.130s) 2023/08/15 13:26:39 [INFO] ✅ TestAccSyncFullFileSync 
(9.370s) 2023/08/15 13:26:50 [INFO] ✅ TestAccSyncIncrementalFileSync (10.390s) 2023/08/15 13:27:00 [INFO] ✅ TestAccSyncNestedFolderSync (10.680s) 2023/08/15 13:27:11 [INFO] ✅ TestAccSyncNestedFolderDoesntFailOnNonEmptyDirectory (10.970s) 2023/08/15 13:27:22 [INFO] ✅ TestAccSyncNestedSpacePlusAndHashAreEscapedSync (10.930s) 2023/08/15 13:27:29 [INFO] ✅ TestAccSyncIncrementalFileOverwritesFolder (7.020s) 2023/08/15 13:27:37 [INFO] ✅ TestAccSyncIncrementalSyncPythonNotebookToFile (7.380s) 2023/08/15 13:27:43 [INFO] ✅ TestAccSyncIncrementalSyncFileToPythonNotebook (6.050s) 2023/08/15 13:27:48 [INFO] ✅ TestAccSyncIncrementalSyncPythonNotebookDelete (5.390s) 2023/08/15 13:27:51 [INFO] ✅ TestAccSyncEnsureRemotePathIsUsableIfRepoDoesntExist (2.570s) 2023/08/15 13:27:56 [INFO] ✅ TestAccSyncEnsureRemotePathIsUsableIfRepoExists (5.540s) 2023/08/15 13:27:58 [INFO] ✅ TestAccSyncEnsureRemotePathIsUsableInWorkspace (1.840s) 2023/08/15 13:27:59 [INFO] ✅ TestAccWorkspaceList (0.790s) 2023/08/15 13:28:08 [INFO] ✅ TestAccExportDir (8.860s) 2023/08/15 13:28:11 [INFO] ✅ TestAccExportDirDoesNotOverwrite (3.090s) 2023/08/15 13:28:14 [INFO] ✅ TestAccExportDirWithOverwriteFlag (3.500s) 2023/08/15 13:28:23 [INFO] ✅ TestAccImportDir (8.330s) 2023/08/15 13:28:34 [INFO] ✅ TestAccImportDirDoesNotOverwrite (10.970s) 2023/08/15 13:28:44 [INFO] ✅ TestAccImportDirWithOverwriteFlag (10.130s) 2023/08/15 13:28:44 [INFO] ✅ 68/68 passed, 0 failed, 3 skipped ```
2023-08-15 13:50:40 +00:00
"crypto/rand"
"crypto/sha256"
_ "embed"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"time"
"github.com/databricks/cli/libs/auth/cache"
2024-03-11 20:58:07 +00:00
"github.com/databricks/databricks-sdk-go/config"
"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/retries"
"github.com/pkg/browser"
"golang.org/x/oauth2"
"golang.org/x/oauth2/authhandler"
)
// OAuth client settings for the CLI.
//
// The client ID and redirect address are predefined by Databricks for this
// public client and are specific to this application only. Using these values
// for other applications is not allowed.
const (
	appClientID     = "databricks-cli"
	appRedirectAddr = "localhost:8020"
)

// DefaultTimeout is the maximum amount of time to wait while acquiring the
// listener on appRedirectAddr.
const DefaultTimeout = 45 * time.Second
// Sentinel errors of the OAuth flow.
//
// Databricks SDK API: the text `databricks OAuth is not` will be checked for
// presence, so the messages of the first two errors must keep that prefix.
var (
	// ErrOAuthNotSupported reports that the host does not support OAuth.
	ErrOAuthNotSupported = errors.New("databricks OAuth is not supported for this host")

	// ErrNotConfigured reports that OAuth is not configured for the host.
	ErrNotConfigured = errors.New("databricks OAuth is not configured for this host")

	// ErrFetchCredentials reports that credentials cannot be fetched.
	ErrFetchCredentials = errors.New("cannot fetch credentials")
)
type PersistentAuth struct {
Host string
AccountID string
2024-03-11 20:58:07 +00:00
Profile string
http httpGet
cache tokenCache
ln net.Listener
browser func(string) error
}
// httpGet is the minimal HTTP capability needed by PersistentAuth: issuing a
// GET request for a URL.
type httpGet interface {
	// Get requests url and returns the response or an error.
	Get(url string) (*http.Response, error)
}
// tokenCache is the storage used by PersistentAuth to keep OAuth tokens
// between invocations, addressed by a string key.
type tokenCache interface {
	// Store saves token under key.
	Store(key string, token *oauth2.Token) error
	// Lookup returns the token saved under key.
	Lookup(key string) (*oauth2.Token, error)
}
func (a *PersistentAuth) Load(ctx context.Context) (*oauth2.Token, error) {
err := a.init(ctx)
if err != nil {
return nil, fmt.Errorf("init: %w", err)
}
// lookup token identified by host (and possibly the account id)
key := a.key()
t, err := a.cache.Lookup(key)
if err != nil {
return nil, fmt.Errorf("cache: %w", err)
}
// early return for valid tokens
if t.Valid() {
// do not print refresh token to end-user
t.RefreshToken = ""
return t, nil
}
// OAuth2 config is invoked only for expired tokens to speed up
// the happy path in the token retrieval
cfg, err := a.oauth2Config()
if err != nil {
return nil, err
}
// eagerly refresh token
2024-03-11 20:58:07 +00:00
ctx = context.WithValue(ctx, oauth2.HTTPClient, a.http)
refreshed, err := cfg.TokenSource(ctx, t).Token()
if err != nil {
return nil, fmt.Errorf("token refresh: %w", err)
}
err = a.cache.Store(key, refreshed)
if err != nil {
return nil, fmt.Errorf("cache refresh: %w", err)
}
// do not print refresh token to end-user
refreshed.RefreshToken = ""
return refreshed, nil
}
func (a *PersistentAuth) ProfileName() string {
2024-03-11 20:58:07 +00:00
if a.Profile != "" {
return a.Profile
}
if a.AccountID != "" {
return fmt.Sprintf("ACCOUNT-%s", a.AccountID)
}
host := strings.TrimPrefix(a.Host, "https://")
split := strings.Split(host, ".")
return split[0]
}
// Challenge obtains a new token through the authorization-code flow with
// PKCE and stores it in the token cache under the key for this host (and
// possibly the account id).
//
// A callback server is created with newCallback for the duration of the call
// and closed on return. Errors are wrapped with the failing stage ("init",
// "callback server", "authorize", "store"); an error from building the
// OAuth2 config is returned as is.
func (a *PersistentAuth) Challenge(ctx context.Context) error {
	if err := a.init(ctx); err != nil {
		return fmt.Errorf("init: %w", err)
	}
	oauthCfg, err := a.oauth2Config()
	if err != nil {
		return err
	}
	server, err := newCallback(ctx, a)
	if err != nil {
		return fmt.Errorf("callback server: %w", err)
	}
	defer server.Close()

	state, pkce := a.stateAndPKCE()
	source := authhandler.TokenSourceWithPKCE(ctx, oauthCfg, state, server.Handler, pkce)
	token, err := source.Token()
	if err != nil {
		return fmt.Errorf("authorize: %w", err)
	}
	// cache token identified by host (and possibly the account id)
	if err := a.cache.Store(a.key(), token); err != nil {
		return fmt.Errorf("store: %w", err)
	}
	return nil
}
func (a *PersistentAuth) init(ctx context.Context) error {
if a.Host == "" && a.AccountID == "" {
return ErrFetchCredentials
}
if a.http == nil {
2024-03-11 20:58:07 +00:00
c := &config.Config{
Profile: a.Profile,
Host: a.Host,
AccountID: a.AccountID,
}
c.EnsureResolved()
clientConfig := httpclient.ClientConfig{
DebugHeaders: c.DebugHeaders,
DebugTruncateBytes: c.DebugTruncateBytes,
InsecureSkipVerify: c.InsecureSkipVerify,
CABundle: c.CABundle,
RetryTimeout: time.Duration(c.RetryTimeoutSeconds) * time.Second,
HTTPTimeout: time.Duration(c.HTTPTimeoutSeconds) * time.Second,
}
httpClient, err := httpclient.NewHttpClient(clientConfig)
if err != nil {
return err
}
a.http = httpClient
}
if a.cache == nil {
a.cache = &cache.TokenCache{}
}
if a.browser == nil {
a.browser = browser.OpenURL
}
// try acquire listener, which we also use as a machine-local
// exclusive lock to prevent token cache corruption in the scope
// of developer machine, where this command runs.
listener, err := retries.Poll(ctx, DefaultTimeout,
func() (*net.Listener, *retries.Err) {
var lc net.ListenConfig
l, err := lc.Listen(ctx, "tcp", appRedirectAddr)
if err != nil {
return nil, retries.Continue(err)
}
return &l, nil
})
if err != nil {
return fmt.Errorf("listener: %w", err)
}
a.ln = *listener
return nil
}
// Close releases the listener held by the receiver, if there is one, and
// returns the listener's Close error. Without a listener it is a no-op.
func (a *PersistentAuth) Close() error {
	if ln := a.ln; ln != nil {
		return ln.Close()
	}
	return nil
}
// oidcEndpoints returns the authorization and token endpoints to use.
//
// With an account id the endpoints are derived directly from the key prefix
// ("<prefix>/v1/authorize" and "<prefix>/v1/token") and no request is made.
// Otherwise the discovery document at
// "<prefix>/oidc/.well-known/oauth-authorization-server" is fetched and
// decoded.
//
// It returns ErrOAuthNotSupported when the discovery request does not answer
// with status 200, and a wrapped error when the request, the read, or the
// JSON decoding fails.
func (a *PersistentAuth) oidcEndpoints() (*oauthAuthorizationServer, error) {
	prefix := a.key()
	if a.AccountID != "" {
		return &oauthAuthorizationServer{
			AuthorizationEndpoint: fmt.Sprintf("%s/v1/authorize", prefix),
			TokenEndpoint:         fmt.Sprintf("%s/v1/token", prefix),
		}, nil
	}
	oidc := fmt.Sprintf("%s/oidc/.well-known/oauth-authorization-server", prefix)
	oidcResponse, err := a.http.Get(oidc)
	if err != nil {
		return nil, fmt.Errorf("fetch .well-known: %w", err)
	}
	// Close the body on every path, including the non-200 one below;
	// otherwise the response body is leaked.
	if oidcResponse.Body != nil {
		defer oidcResponse.Body.Close()
	}
	if oidcResponse.StatusCode != http.StatusOK {
		return nil, ErrOAuthNotSupported
	}
	if oidcResponse.Body == nil {
		return nil, fmt.Errorf("fetch .well-known: empty body")
	}
	raw, err := io.ReadAll(oidcResponse.Body)
	if err != nil {
		return nil, fmt.Errorf("read .well-known: %w", err)
	}
	var oauthEndpoints oauthAuthorizationServer
	if err := json.Unmarshal(raw, &oauthEndpoints); err != nil {
		return nil, fmt.Errorf("parse .well-known: %w", err)
	}
	return &oauthEndpoints, nil
}
// oauth2Config assembles the OAuth2 client configuration for the CLI from the
// resolved endpoints, the fixed client id and the local redirect address.
// Endpoint resolution errors are returned wrapped as "oidc: ...".
func (a *PersistentAuth) oauth2Config() (*oauth2.Config, error) {
	server, err := a.oidcEndpoints()
	if err != nil {
		return nil, fmt.Errorf("oidc: %w", err)
	}
	cfg := &oauth2.Config{
		ClientID:    appClientID,
		RedirectURL: fmt.Sprintf("http://%s", appRedirectAddr),
		// in this iteration of CLI, we're using all scopes by default,
		// because tools like CLI and Terraform do use all apis. This
		// decision may be reconsidered later, once we have a proper
		// taxonomy of all scopes ready and implemented.
		Scopes: []string{"offline_access", "all-apis"},
	}
	cfg.Endpoint = oauth2.Endpoint{
		AuthURL:   server.AuthorizationEndpoint,
		TokenURL:  server.TokenEndpoint,
		AuthStyle: oauth2.AuthStyleInParams,
	}
	return cfg, nil
}
// key is currently used for two purposes: OIDC URL prefix and token cache key.
// Once we decide to start storing scopes in the token cache, we should change
// this approach.
//
// As a side effect it normalizes a.Host in place: one trailing "/" is trimmed
// and "https://" is prepended unless the host already starts with "http://"
// or "https://". With an account id the result is
// "<host>/oidc/accounts/<account id>", otherwise it is the normalized host.
func (a *PersistentAuth) key() string {
	a.Host = strings.TrimSuffix(a.Host, "/")
	// Look for a full scheme rather than the bare "http" prefix, so that a
	// schemeless host such as "httpbin.example.com" still gets one.
	if !strings.HasPrefix(a.Host, "http://") && !strings.HasPrefix(a.Host, "https://") {
		a.Host = "https://" + a.Host
	}
	if a.AccountID != "" {
		return fmt.Sprintf("%s/oidc/accounts/%s", a.Host, a.AccountID)
	}
	return a.Host
}
// stateAndPKCE generates the random state value and the PKCE parameters for
// one authorization request. The challenge is the unpadded URL-safe base64
// encoding of the SHA-256 digest of the verifier (method "S256").
func (a *PersistentAuth) stateAndPKCE() (string, *authhandler.PKCEParams) {
	verifier := a.randomString(64)
	digest := sha256.Sum256([]byte(verifier))
	pkce := &authhandler.PKCEParams{
		Challenge:       base64.RawURLEncoding.EncodeToString(digest[:]),
		ChallengeMethod: "S256",
		Verifier:        verifier,
	}
	state := a.randomString(16)
	return state, pkce
}
// randomString returns size bytes read from crypto/rand, encoded as unpadded
// URL-safe base64 (so the result is longer than size characters).
func (a *PersistentAuth) randomString(size int) string {
	buf := make([]byte, size)
	// NOTE(review): the error from rand.Read is discarded here, as it always
	// has been; confirm this is acceptable for the supported Go versions.
	_, _ = rand.Read(buf)
	encoded := base64.RawURLEncoding.EncodeToString(buf)
	return encoded
}
// oauthAuthorizationServer holds the two endpoints read from the OAuth
// authorization server discovery document.
type oauthAuthorizationServer struct {
	// AuthorizationEndpoint is the authorize URL (../v1/authorize).
	AuthorizationEndpoint string `json:"authorization_endpoint"`
	// TokenEndpoint is the token URL (../v1/token).
	TokenEndpoint string `json:"token_endpoint"`
}