Added `bricks auth login` and `bricks auth token` (#158)

# Auth challenge (happy path)

Simplified description of [PKCE](https://oauth.net/2/pkce/)
implementation:

```mermaid
sequenceDiagram
    autonumber
    actor User
    
    User ->> CLI: type `bricks auth login HOST`
    CLI ->>+ HOST: request OIDC endpoints
    HOST ->>- CLI: auth & token endpoints
    CLI ->> CLI: start embedded server to consume redirects (lock)
    CLI -->>+ Auth Endpoint: open browser with RND1 + SHA256(RND2)

    User ->>+ Auth Endpoint: Go through SSO
    Auth Endpoint ->>- CLI: AUTH CODE + 'RND1 (redirect)

    CLI ->>+ Token Endpoint: Exchange: AUTH CODE + RND2
    Token Endpoint ->>- CLI: Access Token (JWT) + refresh + expiry
    CLI ->> Token cache: Save Access Token (JWT) + refresh + expiry
    CLI ->> User: success
```

# Token refresh (happy path)

```mermaid
sequenceDiagram
    autonumber
    actor User
    
    User ->> CLI: type `bricks token HOST`
    
    CLI ->> CLI: acquire lock (same local addr as redirect server)
    CLI ->>+ Token cache: read token

    critical token not expired
    Token cache ->>- User: JWT (without refresh)

    option token is expired
    CLI ->>+ HOST: request OIDC endpoints
    HOST ->>- CLI: auth & token endpoints
    CLI ->>+ Token Endpoint: refresh token
    Token Endpoint ->>- CLI: JWT (refreshed)
    CLI ->> Token cache: save JWT (refreshed)
    CLI ->> User: JWT (refreshed)
    
    option no auth for host
    CLI -X User: no auth configured
    end
```
This commit is contained in:
Serge Smertin 2023-01-06 16:15:57 +01:00 committed by GitHub
parent a59136f77f
commit b87b4b0f40
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 1343 additions and 2 deletions

5
.vscode/extensions.json vendored Normal file
View File

@ -0,0 +1,5 @@
{
"recommendations": [
"bierner.markdown-mermaid"
]
}

51
cmd/auth/README.md Normal file
View File

@ -0,0 +1,51 @@
# Auth challenge (happy path)
Simplified description of [PKCE](https://oauth.net/2/pkce/) implementation:
```mermaid
sequenceDiagram
autonumber
actor User
User ->> CLI: type `bricks auth login HOST`
CLI ->>+ HOST: request OIDC endpoints
HOST ->>- CLI: auth & token endpoints
CLI ->> CLI: start embedded server to consume redirects (lock)
CLI -->>+ Auth Endpoint: open browser with RND1 + SHA256(RND2)
User ->>+ Auth Endpoint: Go through SSO
Auth Endpoint ->>- CLI: AUTH CODE + 'RND1 (redirect)
CLI ->>+ Token Endpoint: Exchange: AUTH CODE + RND2
Token Endpoint ->>- CLI: Access Token (JWT) + refresh + expiry
CLI ->> Token cache: Save Access Token (JWT) + refresh + expiry
CLI ->> User: success
```
# Token refresh (happy path)
```mermaid
sequenceDiagram
autonumber
actor User
User ->> CLI: type `bricks token HOST`
CLI ->> CLI: acquire lock (same local addr as redirect server)
CLI ->>+ Token cache: read token
critical token not expired
Token cache ->>- User: JWT (without refresh)
option token is expired
CLI ->>+ HOST: request OIDC endpoints
HOST ->>- CLI: auth & token endpoints
CLI ->>+ Token Endpoint: refresh token
Token Endpoint ->>- CLI: JWT (refreshed)
CLI ->> Token cache: save JWT (refreshed)
CLI ->> User: JWT (refreshed)
option no auth for host
CLI -X User: no auth configured
end
```

20
cmd/auth/auth.go Normal file
View File

@ -0,0 +1,20 @@
package auth
import (
"github.com/databricks/bricks/cmd/root"
"github.com/databricks/bricks/libs/auth"
"github.com/spf13/cobra"
)
var authCmd = &cobra.Command{
Use: "auth",
Short: "Authentication related commands",
}
var perisistentAuth auth.PersistentAuth
func init() {
root.RootCmd.AddCommand(authCmd)
authCmd.PersistentFlags().StringVar(&perisistentAuth.Host, "host", perisistentAuth.Host, "Databricks Host")
authCmd.PersistentFlags().StringVar(&perisistentAuth.AccountID, "account-id", perisistentAuth.AccountID, "Databricks Account ID")
}

143
cmd/auth/env.go Normal file
View File

@ -0,0 +1,143 @@
package auth
import (
"encoding/json"
"errors"
"fmt"
"io/fs"
"net/http"
"net/url"
"strings"
"github.com/databricks/databricks-sdk-go/config"
"github.com/spf13/cobra"
"gopkg.in/ini.v1"
)
func canonicalHost(host string) (string, error) {
parsedHost, err := url.Parse(host)
if err != nil {
return "", err
}
// If the host is empty, assume the scheme wasn't included.
if parsedHost.Host == "" {
return fmt.Sprintf("https://%s", host), nil
}
return fmt.Sprintf("https://%s", parsedHost.Host), nil
}
var ErrNoMatchingProfiles = errors.New("no matching profiles found")
func resolveSection(cfg *config.Config, iniFile *ini.File) (*ini.Section, error) {
var candidates []*ini.Section
configuredHost, err := canonicalHost(cfg.Host)
if err != nil {
return nil, err
}
for _, section := range iniFile.Sections() {
hash := section.KeysHash()
host, ok := hash["host"]
if !ok {
// if host is not set
continue
}
canonical, err := canonicalHost(host)
if err != nil {
// we're fine with other corrupt profiles
continue
}
if canonical != configuredHost {
continue
}
candidates = append(candidates, section)
}
if len(candidates) == 0 {
return nil, ErrNoMatchingProfiles
}
// in the real situations, we don't expect this to happen often
// (if not at all), hence we don't trim the list
if len(candidates) > 1 {
var profiles []string
for _, v := range candidates {
profiles = append(profiles, v.Name())
}
return nil, fmt.Errorf("%s match %s in %s",
strings.Join(profiles, " and "), cfg.Host, cfg.ConfigFile)
}
return candidates[0], nil
}
func loadFromDatabricksCfg(cfg *config.Config) error {
iniFile, err := getDatabricksCfg()
if errors.Is(err, fs.ErrNotExist) {
// it's fine not to have ~/.databrickscfg
return nil
}
if err != nil {
return err
}
profile, err := resolveSection(cfg, iniFile)
if err == ErrNoMatchingProfiles {
// it's also fine for Azure CLI or Bricks CLI, which
// are resolved by unified auth handling in the Go SDK.
return nil
}
if err != nil {
return err
}
cfg.Profile = profile.Name()
return nil
}
var envCmd = &cobra.Command{
Use: "env",
Short: "Get env",
RunE: func(cmd *cobra.Command, args []string) error {
cfg := &config.Config{
Host: host,
Profile: profile,
}
if profile != "" {
cfg.Profile = profile
} else if cfg.Host == "" {
cfg.Profile = "DEFAULT"
} else if err := loadFromDatabricksCfg(cfg); err != nil {
return err
}
// Go SDK is lazy loaded because of Terraform semantics,
// so we're creating a dummy HTTP request as a placeholder
// for headers.
r := &http.Request{Header: http.Header{}}
err := cfg.Authenticate(r.WithContext(cmd.Context()))
if err != nil {
return err
}
vars := map[string]string{}
for _, a := range config.ConfigAttributes {
if a.IsZero(cfg) {
continue
}
envValue := a.GetString(cfg)
for _, envName := range a.EnvVars {
vars[envName] = envValue
}
}
raw, err := json.MarshalIndent(map[string]any{
"env": vars,
}, "", " ")
if err != nil {
return err
}
cmd.OutOrStdout().Write(raw)
return nil
},
}
var host string
var profile string
func init() {
authCmd.AddCommand(envCmd)
envCmd.Flags().StringVar(&host, "host", host, "Hostname to get auth env for")
envCmd.Flags().StringVar(&profile, "profile", profile, "Profile to get auth env for")
}

31
cmd/auth/login.go Normal file
View File

@ -0,0 +1,31 @@
package auth
import (
"context"
"time"
"github.com/databricks/bricks/libs/auth"
"github.com/spf13/cobra"
)
var loginTimeout time.Duration
var loginCmd = &cobra.Command{
Use: "login [HOST]",
Short: "Authenticate this machine",
RunE: func(cmd *cobra.Command, args []string) error {
if perisistentAuth.Host == "" && len(args) == 1 {
perisistentAuth.Host = args[0]
}
defer perisistentAuth.Close()
ctx, cancel := context.WithTimeout(cmd.Context(), loginTimeout)
defer cancel()
return perisistentAuth.Challenge(ctx)
},
}
func init() {
authCmd.AddCommand(loginCmd)
loginCmd.Flags().DurationVar(&loginTimeout, "timeout", auth.DefaultTimeout,
"Timeout for completing login challenge in the browser")
}

131
cmd/auth/profiles.go Normal file
View File

@ -0,0 +1,131 @@
package auth
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/config"
"github.com/spf13/cobra"
"gopkg.in/ini.v1"
)
func getDatabricksCfg() (*ini.File, error) {
configFile := os.Getenv("DATABRICKS_CONFIG_FILE")
if configFile == "" {
configFile = "~/.databrickscfg"
}
if strings.HasPrefix(configFile, "~") {
homedir, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("cannot find homedir: %w", err)
}
configFile = filepath.Join(homedir, configFile[1:])
}
return ini.Load(configFile)
}
type profileMetadata struct {
Name string `json:"name"`
Host string `json:"host,omitempty"`
AccountID string `json:"account_id,omitempty"`
Cloud string `json:"cloud"`
AuthType string `json:"auth_type"`
Valid bool `json:"valid"`
}
func (c *profileMetadata) IsEmpty() bool {
return c.Host == "" && c.AccountID == ""
}
func (c *profileMetadata) Load(ctx context.Context) {
// TODO: disable config loaders other than configfile
cfg := &config.Config{Profile: c.Name}
_ = cfg.EnsureResolved()
if cfg.IsAws() {
c.Cloud = "aws"
} else if cfg.IsAzure() {
c.Cloud = "azure"
} else if cfg.IsGcp() {
c.Cloud = "gcp"
}
if cfg.IsAccountClient() {
a, err := databricks.NewAccountClient((*databricks.Config)(cfg))
if err != nil {
return
}
_, err = a.Workspaces.List(ctx)
c.AuthType = cfg.AuthType
if err != nil {
return
}
c.Valid = true
} else {
w, err := databricks.NewWorkspaceClient((*databricks.Config)(cfg))
if err != nil {
return
}
_, err = w.Tokens.ListAll(ctx)
c.AuthType = cfg.AuthType
if err != nil {
return
}
c.Valid = true
}
// set host again, this time normalized
c.Host = cfg.Host
}
var profilesCmd = &cobra.Command{
Use: "profiles",
Short: "Lists profiles from ~/.databrickscfg",
RunE: func(cmd *cobra.Command, args []string) error {
iniFile, err := getDatabricksCfg()
if os.IsNotExist(err) {
// early return for non-configured machines
return errors.New("~/.databrickcfg not found on current host")
}
if err != nil {
return fmt.Errorf("cannot parse config file: %w", err)
}
var profiles []*profileMetadata
var wg sync.WaitGroup
for _, v := range iniFile.Sections() {
hash := v.KeysHash()
profile := &profileMetadata{
Name: v.Name(),
Host: hash["host"],
AccountID: hash["account_id"],
}
if profile.IsEmpty() {
continue
}
wg.Add(1)
go func() {
// load more information about profile
profile.Load(cmd.Context())
wg.Done()
}()
profiles = append(profiles, profile)
}
wg.Wait()
raw, err := json.MarshalIndent(map[string]any{
"profiles": profiles,
}, "", " ")
if err != nil {
return err
}
cmd.OutOrStdout().Write(raw)
return nil
},
}
func init() {
authCmd.AddCommand(profilesCmd)
}

41
cmd/auth/token.go Normal file
View File

@ -0,0 +1,41 @@
package auth
import (
"context"
"encoding/json"
"time"
"github.com/databricks/bricks/libs/auth"
"github.com/spf13/cobra"
)
var tokenTimeout time.Duration
var tokenCmd = &cobra.Command{
Use: "token [HOST]",
Short: "Get authentication token",
RunE: func(cmd *cobra.Command, args []string) error {
if perisistentAuth.Host == "" && len(args) == 1 {
perisistentAuth.Host = args[0]
}
defer perisistentAuth.Close()
ctx, cancel := context.WithTimeout(cmd.Context(), tokenTimeout)
defer cancel()
t, err := perisistentAuth.Load(ctx)
if err != nil {
return err
}
raw, err := json.MarshalIndent(t, "", " ")
if err != nil {
return err
}
cmd.OutOrStdout().Write(raw)
return nil
},
}
func init() {
authCmd.AddCommand(tokenCmd)
tokenCmd.Flags().DurationVar(&tokenTimeout, "timeout", auth.DefaultTimeout,
"Timeout for acquiring a token.")
}

4
go.mod
View File

@ -48,9 +48,9 @@ require (
github.com/spf13/pflag v1.0.5 github.com/spf13/pflag v1.0.5
go.opencensus.io v0.24.0 // indirect go.opencensus.io v0.24.0 // indirect
golang.org/x/net v0.1.0 // indirect golang.org/x/net v0.1.0 // indirect
golang.org/x/oauth2 v0.0.0-20221014153046-6fdb5e3db783 // indirect golang.org/x/oauth2 v0.0.0-20221014153046-6fdb5e3db783
golang.org/x/sys v0.1.0 // indirect golang.org/x/sys v0.1.0 // indirect
golang.org/x/text v0.5.0 // indirect golang.org/x/text v0.5.0
golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac // indirect golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac // indirect
google.golang.org/api v0.105.0 // indirect google.golang.org/api v0.105.0 // indirect
google.golang.org/appengine v1.6.7 // indirect google.golang.org/appengine v1.6.7 // indirect

106
libs/auth/cache/cache.go vendored Normal file
View File

@ -0,0 +1,106 @@
package cache
import (
"encoding/json"
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"golang.org/x/oauth2"
)
const (
// where the token cache is stored
tokenCacheFile = ".databricks/token-cache.json"
// only the owner of the file has full execute, read, and write access
ownerExecReadWrite = 0o700
// only the owner of the file has full read and write access
ownerReadWrite = 0o600
// format versioning leaves some room for format improvement
tokenCacheVersion = 1
)
var ErrNotConfigured = errors.New("databricks OAuth is not configured for this host")
// this implementation requires the calling code to do a machine-wide lock,
// otherwise the file might get corrupt.
type TokenCache struct {
Version int `json:"version"`
Tokens map[string]*oauth2.Token `json:"tokens"`
fileLocation string
}
func (c *TokenCache) Store(key string, t *oauth2.Token) error {
err := c.load()
if errors.Is(err, fs.ErrNotExist) {
dir := filepath.Dir(c.fileLocation)
err = os.MkdirAll(dir, ownerExecReadWrite)
if err != nil {
return fmt.Errorf("mkdir: %w", err)
}
} else if err != nil {
return fmt.Errorf("load: %w", err)
}
c.Version = tokenCacheVersion
if c.Tokens == nil {
c.Tokens = map[string]*oauth2.Token{}
}
c.Tokens[key] = t
raw, err := json.MarshalIndent(c, "", " ")
if err != nil {
return fmt.Errorf("marshal: %w", err)
}
return os.WriteFile(c.fileLocation, raw, ownerReadWrite)
}
func (c *TokenCache) Lookup(key string) (*oauth2.Token, error) {
err := c.load()
if errors.Is(err, fs.ErrNotExist) {
return nil, ErrNotConfigured
} else if err != nil {
return nil, fmt.Errorf("load: %w", err)
}
t, ok := c.Tokens[key]
if !ok {
return nil, ErrNotConfigured
}
return t, nil
}
func (c *TokenCache) location() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("home: %w", err)
}
return filepath.Join(home, tokenCacheFile), nil
}
func (c *TokenCache) load() error {
loc, err := c.location()
if err != nil {
return err
}
c.fileLocation = loc
raw, err := os.ReadFile(loc)
if err != nil {
return fmt.Errorf("read: %w", err)
}
err = json.Unmarshal(raw, c)
if err != nil {
return fmt.Errorf("parse: %w", err)
}
if c.Version != tokenCacheVersion {
// in the later iterations we could do state upgraders,
// so that we transform token cache from v1 to v2 without
// losing the tokens and asking the user to re-authenticate.
return fmt.Errorf("needs version %d, got version %d",
tokenCacheVersion, c.Version)
}
return nil
}

105
libs/auth/cache/cache_test.go vendored Normal file
View File

@ -0,0 +1,105 @@
package cache
import (
"os"
"path/filepath"
"runtime"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
)
var homeEnvVar = "HOME"
func init() {
if runtime.GOOS == "windows" {
homeEnvVar = "USERPROFILE"
}
}
func setup(t *testing.T) string {
tempHomeDir := t.TempDir()
t.Setenv(homeEnvVar, tempHomeDir)
return tempHomeDir
}
func TestStoreAndLookup(t *testing.T) {
setup(t)
c := &TokenCache{}
err := c.Store("x", &oauth2.Token{
AccessToken: "abc",
})
require.NoError(t, err)
err = c.Store("y", &oauth2.Token{
AccessToken: "bcd",
})
require.NoError(t, err)
l := &TokenCache{}
tok, err := l.Lookup("x")
require.NoError(t, err)
assert.Equal(t, "abc", tok.AccessToken)
assert.Equal(t, 2, len(l.Tokens))
_, err = l.Lookup("z")
assert.Equal(t, ErrNotConfigured, err)
}
func TestNoCacheFileReturnsErrNotConfigured(t *testing.T) {
setup(t)
l := &TokenCache{}
_, err := l.Lookup("x")
assert.Equal(t, ErrNotConfigured, err)
}
func TestLoadCorruptFile(t *testing.T) {
home := setup(t)
f := filepath.Join(home, tokenCacheFile)
err := os.MkdirAll(filepath.Dir(f), ownerExecReadWrite)
require.NoError(t, err)
err = os.WriteFile(f, []byte("abc"), ownerExecReadWrite)
require.NoError(t, err)
l := &TokenCache{}
_, err = l.Lookup("x")
assert.EqualError(t, err, "load: parse: invalid character 'a' looking for beginning of value")
}
func TestLoadWrongVersion(t *testing.T) {
home := setup(t)
f := filepath.Join(home, tokenCacheFile)
err := os.MkdirAll(filepath.Dir(f), ownerExecReadWrite)
require.NoError(t, err)
err = os.WriteFile(f, []byte(`{"version": 823, "things": []}`), ownerExecReadWrite)
require.NoError(t, err)
l := &TokenCache{}
_, err = l.Lookup("x")
assert.EqualError(t, err, "load: needs version 1, got version 823")
}
func TestDevNull(t *testing.T) {
t.Setenv(homeEnvVar, "/dev/null")
l := &TokenCache{}
_, err := l.Lookup("x")
// macOS/Linux: load: read: open /dev/null/.databricks/token-cache.json:
// windows: databricks OAuth is not configured for this host
assert.Error(t, err)
}
func TestStoreOnDev(t *testing.T) {
if runtime.GOOS == "windows" {
t.SkipNow()
}
t.Setenv(homeEnvVar, "/dev")
c := &TokenCache{}
err := c.Store("x", &oauth2.Token{
AccessToken: "abc",
})
// Linux: permission denied
// macOS: read-only file system
assert.Error(t, err)
}

102
libs/auth/callback.go Normal file
View File

@ -0,0 +1,102 @@
package auth
import (
"context"
_ "embed"
"fmt"
"html/template"
"net"
"net/http"
"strings"
"golang.org/x/text/cases"
"golang.org/x/text/language"
)
//go:embed page.tmpl
var pageTmpl string
type oauthResult struct {
Error string
ErrorDescription string
Host string
State string
Code string
}
type callbackServer struct {
ln net.Listener
srv http.Server
ctx context.Context
a *PersistentAuth
renderErrCh chan error
feedbackCh chan oauthResult
tmpl *template.Template
}
func newCallback(ctx context.Context, a *PersistentAuth) (*callbackServer, error) {
tmpl, err := template.New("page").Funcs(template.FuncMap{
"title": func(in string) string {
title := cases.Title(language.English)
return title.String(strings.ReplaceAll(in, "_", " "))
},
}).Parse(pageTmpl)
if err != nil {
return nil, err
}
cb := &callbackServer{
feedbackCh: make(chan oauthResult),
renderErrCh: make(chan error),
tmpl: tmpl,
ctx: ctx,
ln: a.ln,
a: a,
}
cb.srv.Handler = cb
go cb.srv.Serve(cb.ln)
return cb, nil
}
func (cb *callbackServer) Close() error {
return cb.srv.Close()
}
// ServeHTTP renders page.html template
func (cb *callbackServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
res := oauthResult{
Error: r.FormValue("error"),
ErrorDescription: r.FormValue("error_description"),
Code: r.FormValue("code"),
State: r.FormValue("state"),
Host: cb.a.Host,
}
if res.Error != "" {
w.WriteHeader(http.StatusBadRequest)
} else {
w.WriteHeader(http.StatusOK)
}
err := cb.tmpl.Execute(w, res)
if err != nil {
cb.renderErrCh <- err
}
cb.feedbackCh <- res
}
// Handler opens up a browser waits for redirect to come back from the identity provider
func (cb *callbackServer) Handler(authCodeURL string) (string, string, error) {
err := cb.a.browser(authCodeURL)
if err != nil {
fmt.Printf("Please open %s in the browser to continue authentication", authCodeURL)
}
select {
case <-cb.ctx.Done():
return "", "", cb.ctx.Err()
case renderErr := <-cb.renderErrCh:
return "", "", renderErr
case res := <-cb.feedbackCh:
if res.Error != "" {
return "", "", fmt.Errorf("%s: %s", res.Error, res.ErrorDescription)
}
return res.Code, res.State, nil
}
}

268
libs/auth/oauth.go Normal file
View File

@ -0,0 +1,268 @@
package auth
import (
"context"
"crypto/sha256"
_ "embed"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"strings"
"time"
"github.com/databricks/bricks/libs/auth/cache"
"github.com/databricks/databricks-sdk-go/retries"
"github.com/pkg/browser"
"golang.org/x/oauth2"
"golang.org/x/oauth2/authhandler"
)
const (
// these values are predefined by Databricks as a public client
// and is specific to this application only. Using these values
// for other applications is not allowed.
appClientID = "databricks-cli"
appRedirectAddr = "localhost:8020"
// maximum amount of time to acquire listener on appRedirectAddr
DefaultTimeout = 15 * time.Second
)
var ( // Databricks SDK API: `databricks OAuth is not` will be checked for presence
ErrOAuthNotSupported = errors.New("databricks OAuth is not supported for this host")
ErrNotConfigured = errors.New("databricks OAuth is not configured for this host")
ErrFetchCredentials = errors.New("cannot fetch credentials")
)
type PersistentAuth struct {
Host string
AccountID string
http httpGet
cache tokenCache
ln net.Listener
browser func(string) error
}
type httpGet interface {
Get(string) (*http.Response, error)
}
type tokenCache interface {
Store(key string, t *oauth2.Token) error
Lookup(key string) (*oauth2.Token, error)
}
func (a *PersistentAuth) Load(ctx context.Context) (*oauth2.Token, error) {
err := a.init(ctx)
if err != nil {
return nil, fmt.Errorf("init: %w", err)
}
// lookup token identified by host (and possibly the account id)
key := a.key()
t, err := a.cache.Lookup(key)
if err != nil {
return nil, fmt.Errorf("cache: %w", err)
}
// early return for valid tokens
if t.Valid() {
// do not print refresh token to end-user
t.RefreshToken = ""
return t, nil
}
// OAuth2 config is invoked only for expired tokens to speed up
// the happy path in the token retrieval
cfg, err := a.oauth2Config()
if err != nil {
return nil, err
}
// eagerly refresh token
refreshed, err := cfg.TokenSource(ctx, t).Token()
if err != nil {
return nil, fmt.Errorf("token refresh: %w", err)
}
err = a.cache.Store(key, refreshed)
if err != nil {
return nil, fmt.Errorf("cache refresh: %w", err)
}
// do not print refresh token to end-user
refreshed.RefreshToken = ""
return refreshed, nil
}
func (a *PersistentAuth) Challenge(ctx context.Context) error {
err := a.init(ctx)
if err != nil {
return fmt.Errorf("init: %w", err)
}
cfg, err := a.oauth2Config()
if err != nil {
return err
}
cb, err := newCallback(ctx, a)
if err != nil {
return fmt.Errorf("callback server: %w", err)
}
defer cb.Close()
state, pkce := a.stateAndPKCE()
ts := authhandler.TokenSourceWithPKCE(ctx, cfg, state, cb.Handler, pkce)
t, err := ts.Token()
if err != nil {
return fmt.Errorf("authorize: %w", err)
}
// cache token identified by host (and possibly the account id)
err = a.cache.Store(a.key(), t)
if err != nil {
return fmt.Errorf("store: %w", err)
}
return nil
}
func (a *PersistentAuth) init(ctx context.Context) error {
if a.Host == "" && a.AccountID == "" {
return ErrFetchCredentials
}
if a.http == nil {
a.http = http.DefaultClient
}
if a.cache == nil {
a.cache = &cache.TokenCache{}
}
if a.browser == nil {
a.browser = browser.OpenURL
}
// try acquire listener, which we also use as a machine-local
// exclusive lock to prevent token cache corruption in the scope
// of developer machine, where this command runs.
listener, err := retries.Poll(ctx, DefaultTimeout,
func() (*net.Listener, *retries.Err) {
var lc net.ListenConfig
l, err := lc.Listen(ctx, "tcp", appRedirectAddr)
if err != nil {
return nil, retries.Continue(err)
}
return &l, nil
})
if err != nil {
return fmt.Errorf("listener: %w", err)
}
a.ln = *listener
return nil
}
func (a *PersistentAuth) Close() error {
if a.ln == nil {
return nil
}
return a.ln.Close()
}
func (a *PersistentAuth) oidcEndpoints() (*oauthAuthorizationServer, error) {
prefix := a.key()
if a.AccountID != "" {
return &oauthAuthorizationServer{
AuthorizationEndpoint: fmt.Sprintf("%s/v1/authorize", prefix),
TokenEndpoint: fmt.Sprintf("%s/v1/token", prefix),
}, nil
}
oidc := fmt.Sprintf("%s/oidc/.well-known/oauth-authorization-server", prefix)
oidcResponse, err := a.http.Get(oidc)
if err != nil {
return nil, fmt.Errorf("fetch .well-known: %w", err)
}
if oidcResponse.StatusCode != 200 {
return nil, ErrOAuthNotSupported
}
if oidcResponse.Body == nil {
return nil, fmt.Errorf("fetch .well-known: empty body")
}
defer oidcResponse.Body.Close()
raw, err := io.ReadAll(oidcResponse.Body)
if err != nil {
return nil, fmt.Errorf("read .well-known: %w", err)
}
var oauthEndpoints oauthAuthorizationServer
err = json.Unmarshal(raw, &oauthEndpoints)
if err != nil {
return nil, fmt.Errorf("parse .well-known: %w", err)
}
return &oauthEndpoints, nil
}
func (a *PersistentAuth) oauth2Config() (*oauth2.Config, error) {
// in this iteration of CLI, we're using all scopes by default,
// because tools like CLI and Terraform do use all apis. This
// decision may be reconsidered later, once we have a proper
// taxonomy of all scopes ready and implemented.
scopes := []string{
"offline_access",
"unity-catalog",
"accounts",
"clusters",
"mlflow",
"scim",
"sql",
}
if a.AccountID != "" {
scopes = []string{
"offline_access",
"accounts",
}
}
endpoints, err := a.oidcEndpoints()
if err != nil {
return nil, fmt.Errorf("oidc: %w", err)
}
return &oauth2.Config{
ClientID: appClientID,
Endpoint: oauth2.Endpoint{
AuthURL: endpoints.AuthorizationEndpoint,
TokenURL: endpoints.TokenEndpoint,
AuthStyle: oauth2.AuthStyleInParams,
},
RedirectURL: fmt.Sprintf("http://%s", appRedirectAddr),
Scopes: scopes,
}, nil
}
// key is currently used for two purposes: OIDC URL prefix and token cache key.
// once we decide to start storing scopes in the token cache, we should change
// this approach.
func (a *PersistentAuth) key() string {
a.Host = strings.TrimSuffix(a.Host, "/")
if !strings.HasPrefix(a.Host, "http") {
a.Host = fmt.Sprintf("https://%s", a.Host)
}
if a.AccountID != "" {
return fmt.Sprintf("%s/oidc/accounts/%s", a.Host, a.AccountID)
}
return a.Host
}
func (a *PersistentAuth) stateAndPKCE() (string, *authhandler.PKCEParams) {
verifier := a.randomString(64)
verifierSha256 := sha256.Sum256([]byte(verifier))
challenge := base64.RawURLEncoding.EncodeToString(verifierSha256[:])
return a.randomString(16), &authhandler.PKCEParams{
Challenge: challenge,
ChallengeMethod: "S256",
Verifier: verifier,
}
}
func (a *PersistentAuth) randomString(size int) string {
rand.Seed(time.Now().UnixNano())
raw := make([]byte, size)
_, _ = rand.Read(raw)
return base64.RawURLEncoding.EncodeToString(raw)
}
type oauthAuthorizationServer struct {
AuthorizationEndpoint string `json:"authorization_endpoint"` // ../v1/authorize
TokenEndpoint string `json:"token_endpoint"` // ../v1/token
}

235
libs/auth/oauth_test.go Normal file
View File

@ -0,0 +1,235 @@
package auth
import (
"context"
"crypto/tls"
_ "embed"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"testing"
"time"
"github.com/databricks/databricks-sdk-go/client"
"github.com/databricks/databricks-sdk-go/qa"
"github.com/stretchr/testify/assert"
"golang.org/x/oauth2"
)
func TestOidcEndpointsForAccounts(t *testing.T) {
p := &PersistentAuth{
Host: "abc",
AccountID: "xyz",
}
defer p.Close()
s, err := p.oidcEndpoints()
assert.NoError(t, err)
assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/authorize", s.AuthorizationEndpoint)
assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/token", s.TokenEndpoint)
}
type mockGet func(url string) (*http.Response, error)
func (m mockGet) Get(url string) (*http.Response, error) {
return m(url)
}
func TestOidcForWorkspace(t *testing.T) {
p := &PersistentAuth{
Host: "abc",
http: mockGet(func(url string) (*http.Response, error) {
assert.Equal(t, "https://abc/oidc/.well-known/oauth-authorization-server", url)
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(strings.NewReader(`{
"authorization_endpoint": "a",
"token_endpoint": "b"
}`)),
}, nil
}),
}
defer p.Close()
endpoints, err := p.oidcEndpoints()
assert.NoError(t, err)
assert.Equal(t, "a", endpoints.AuthorizationEndpoint)
assert.Equal(t, "b", endpoints.TokenEndpoint)
}
type tokenCacheMock struct {
store func(key string, t *oauth2.Token) error
lookup func(key string) (*oauth2.Token, error)
}
func (m *tokenCacheMock) Store(key string, t *oauth2.Token) error {
if m.store == nil {
panic("no store mock")
}
return m.store(key, t)
}
func (m *tokenCacheMock) Lookup(key string) (*oauth2.Token, error) {
if m.lookup == nil {
panic("no lookup mock")
}
return m.lookup(key)
}
func TestLoad(t *testing.T) {
p := &PersistentAuth{
Host: "abc",
AccountID: "xyz",
cache: &tokenCacheMock{
lookup: func(key string) (*oauth2.Token, error) {
assert.Equal(t, "https://abc/oidc/accounts/xyz", key)
return &oauth2.Token{
AccessToken: "bcd",
Expiry: time.Now().Add(1 * time.Minute),
}, nil
},
},
}
defer p.Close()
tok, err := p.Load(context.Background())
assert.NoError(t, err)
assert.Equal(t, "bcd", tok.AccessToken)
assert.Equal(t, "", tok.RefreshToken)
}
func useInsecureOAuthHttpClientForTests(ctx context.Context) context.Context {
return context.WithValue(ctx, oauth2.HTTPClient, &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
})
}
func TestLoadRefresh(t *testing.T) {
qa.HTTPFixtures{
{
Method: "POST",
Resource: "/oidc/accounts/xyz/v1/token",
Response: `access_token=refreshed&refresh_token=def`,
},
}.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) {
ctx = useInsecureOAuthHttpClientForTests(ctx)
expectedKey := fmt.Sprintf("%s/oidc/accounts/xyz", c.Config.Host)
p := &PersistentAuth{
Host: c.Config.Host,
AccountID: "xyz",
cache: &tokenCacheMock{
lookup: func(key string) (*oauth2.Token, error) {
assert.Equal(t, expectedKey, key)
return &oauth2.Token{
AccessToken: "expired",
RefreshToken: "cde",
Expiry: time.Now().Add(-1 * time.Minute),
}, nil
},
store: func(key string, tok *oauth2.Token) error {
assert.Equal(t, expectedKey, key)
assert.Equal(t, "def", tok.RefreshToken)
return nil
},
},
}
defer p.Close()
tok, err := p.Load(ctx)
assert.NoError(t, err)
assert.Equal(t, "refreshed", tok.AccessToken)
assert.Equal(t, "", tok.RefreshToken)
})
}
func TestChallenge(t *testing.T) {
qa.HTTPFixtures{
{
Method: "POST",
Resource: "/oidc/accounts/xyz/v1/token",
Response: `access_token=__THAT__&refresh_token=__SOMETHING__`,
},
}.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) {
ctx = useInsecureOAuthHttpClientForTests(ctx)
expectedKey := fmt.Sprintf("%s/oidc/accounts/xyz", c.Config.Host)
browserOpened := make(chan string)
p := &PersistentAuth{
Host: c.Config.Host,
AccountID: "xyz",
browser: func(redirect string) error {
u, err := url.ParseRequestURI(redirect)
if err != nil {
return err
}
assert.Equal(t, "/oidc/accounts/xyz/v1/authorize", u.Path)
// for now we're ignoring asserting the fields of the redirect
query := u.Query()
browserOpened <- query.Get("state")
return nil
},
cache: &tokenCacheMock{
store: func(key string, tok *oauth2.Token) error {
assert.Equal(t, expectedKey, key)
assert.Equal(t, "__SOMETHING__", tok.RefreshToken)
return nil
},
},
}
defer p.Close()
errc := make(chan error)
go func() {
errc <- p.Challenge(ctx)
}()
state := <-browserOpened
resp, err := http.Get(fmt.Sprintf("http://%s?code=__THIS__&state=%s", appRedirectAddr, state))
assert.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)
err = <-errc
assert.NoError(t, err)
})
}
func TestChallengeFailed(t *testing.T) {
qa.HTTPFixtures{}.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) {
ctx = useInsecureOAuthHttpClientForTests(ctx)
browserOpened := make(chan string)
p := &PersistentAuth{
Host: c.Config.Host,
AccountID: "xyz",
browser: func(redirect string) error {
u, err := url.ParseRequestURI(redirect)
if err != nil {
return err
}
assert.Equal(t, "/oidc/accounts/xyz/v1/authorize", u.Path)
// for now we're ignoring asserting the fields of the redirect
query := u.Query()
browserOpened <- query.Get("state")
return nil
},
}
defer p.Close()
errc := make(chan error)
go func() {
errc <- p.Challenge(ctx)
}()
<-browserOpened
resp, err := http.Get(fmt.Sprintf(
"http://%s?error=access_denied&error_description=Policy%%20evaluation%%20failed%%20for%%20this%%20request",
appRedirectAddr))
assert.NoError(t, err)
assert.Equal(t, 400, resp.StatusCode)
err = <-errc
assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request")
})
}

102
libs/auth/page.tmpl Normal file
View File

@ -0,0 +1,102 @@
<!DOCTYPE html SYSTEM "http://www.thymeleaf.org/dtd/xhtml1-strict-thymeleaf-4.dtd">
<html xmlns="http://www.w3.org/1999/xhtml" xmlns:th="http://www.thymeleaf.org">
<head>
<title>{{if .Error }}{{ .Error | title }}{{ else }}Success{{end}}</title>
<link rel="preconnect" href="https://fonts.gstatic.com" />
<link href="https://fonts.googleapis.com/css2?family=DM+Sans&display=swap" rel="stylesheet" />
<style>
html,
body {
height: 100%;
}
body {
font-family: "DM Sans";
font-style: normal;
font-size: 14px;
margin: 0;
padding: 0;
height: 100%;
width: 100%;
background: #f5f6f6;
align-items: center;
}
.root-container {
display: flex;
height: 100%;
align-items: center;
justify-content: center;
}
.info-container {
width: 320px;
box-shadow: 0px 2px 4px rgba(0, 0, 0, 0.1),
0px 8px 25px rgba(0, 0, 0, 0.1);
border-radius: 8px;
display: flex;
flex-direction: column;
padding: 48px;
background: #fff;
justify-content: center;
align-items: center;
text-align: center;
gap: 24px;
}
.title {
font-weight: 600;
font-size: 24px;
line-height: 28px;
}
a {
color: #C4CCD6;
}
a:hover {
color: #90A5B1;
}
.content {
width: 300px;
font-size: 14px;
}
.button {
display: flex;
background: #1B3139;
align-items: center;
justify-content: center;
height: 40px;
width: 300px;
border-radius: 4px;
text-align: center;
text-decoration: none;
color: #ffffff !important;
}
</style>
</head>
<body>
<div class="root-container">
<div class="info-container">
<img
src=""
/>
<!-- {{ if .Error }} -->
<div class="title">{{ .Error | title }}</div>
<div class="content">{{ .ErrorDescription }}</div>
<!-- {{ else }} -->
<div class="title">Authenticated</div>
<div class="content">Go to <a href="https://{{.Host}}">{{.Host}}</a></div>
<!-- {{ end }} -->
<div class="content">
You can close this tab. Or go to <a href="https://docs.databricks.com/dev-tools/index-cli.html">documentation</a>
</div>
</div>
</div>
</body>
</html>

View File

@ -2,6 +2,7 @@ package main
import ( import (
_ "github.com/databricks/bricks/cmd/api" _ "github.com/databricks/bricks/cmd/api"
_ "github.com/databricks/bricks/cmd/auth"
_ "github.com/databricks/bricks/cmd/bundle" _ "github.com/databricks/bricks/cmd/bundle"
_ "github.com/databricks/bricks/cmd/bundle/debug" _ "github.com/databricks/bricks/cmd/bundle/debug"
_ "github.com/databricks/bricks/cmd/bundle/debug/deploy" _ "github.com/databricks/bricks/cmd/bundle/debug/deploy"