diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 00000000..9bf69387 --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,5 @@ +{ + "recommendations": [ + "bierner.markdown-mermaid" + ] +} \ No newline at end of file diff --git a/cmd/auth/README.md b/cmd/auth/README.md new file mode 100644 index 00000000..6798f12e --- /dev/null +++ b/cmd/auth/README.md @@ -0,0 +1,51 @@ +# Auth challenge (happy path) + +Simplified description of [PKCE](https://oauth.net/2/pkce/) implementation: + +```mermaid +sequenceDiagram + autonumber + actor User + + User ->> CLI: type `bricks auth login HOST` + CLI ->>+ HOST: request OIDC endpoints + HOST ->>- CLI: auth & token endpoints + CLI ->> CLI: start embedded server to consume redirects (lock) + CLI -->>+ Auth Endpoint: open browser with RND1 + SHA256(RND2) + + User ->>+ Auth Endpoint: Go through SSO + Auth Endpoint ->>- CLI: AUTH CODE + 'RND1 (redirect) + + CLI ->>+ Token Endpoint: Exchange: AUTH CODE + RND2 + Token Endpoint ->>- CLI: Access Token (JWT) + refresh + expiry + CLI ->> Token cache: Save Access Token (JWT) + refresh + expiry + CLI ->> User: success +``` + +# Token refresh (happy path) + +```mermaid +sequenceDiagram + autonumber + actor User + + User ->> CLI: type `bricks token HOST` + + CLI ->> CLI: acquire lock (same local addr as redirect server) + CLI ->>+ Token cache: read token + + critical token not expired + Token cache ->>- User: JWT (without refresh) + + option token is expired + CLI ->>+ HOST: request OIDC endpoints + HOST ->>- CLI: auth & token endpoints + CLI ->>+ Token Endpoint: refresh token + Token Endpoint ->>- CLI: JWT (refreshed) + CLI ->> Token cache: save JWT (refreshed) + CLI ->> User: JWT (refreshed) + + option no auth for host + CLI -X User: no auth configured + end +``` \ No newline at end of file diff --git a/cmd/auth/auth.go b/cmd/auth/auth.go new file mode 100644 index 00000000..13c2068e --- /dev/null +++ b/cmd/auth/auth.go @@ -0,0 +1,20 @@ +package auth + 
// canonicalHost normalizes a user-supplied host into the form
// "https://<host[:port]>", so that hosts coming from flags and from
// ~/.databrickscfg can be compared with plain string equality.
//
// The scheme is always forced to https, and any path, query or trailing
// slash is dropped. Inputs without a scheme (e.g. "example.com/") are
// accepted as well. An error is returned only when the input cannot be
// parsed as a URL at all.
func canonicalHost(host string) (string, error) {
	parsedHost, err := url.Parse(host)
	if err != nil {
		return "", err
	}
	if parsedHost.Host != "" {
		return fmt.Sprintf("https://%s", parsedHost.Host), nil
	}
	// An empty Host means the scheme wasn't included, so url.Parse treated
	// the whole input as a path (or as scheme:opaque for "host:port").
	// Re-parse with an explicit scheme so that only the authority part is
	// kept; otherwise "example.com/" would never match "https://example.com".
	if withScheme, reparseErr := url.Parse("https://" + host); reparseErr == nil && withScheme.Host != "" {
		return fmt.Sprintf("https://%s", withScheme.Host), nil
	}
	// Nothing better can be derived: keep the previous best-effort result.
	return fmt.Sprintf("https://%s", host), nil
}
+ return nil + } + if err != nil { + return err + } + cfg.Profile = profile.Name() + return nil +} + +var envCmd = &cobra.Command{ + Use: "env", + Short: "Get env", + RunE: func(cmd *cobra.Command, args []string) error { + cfg := &config.Config{ + Host: host, + Profile: profile, + } + if profile != "" { + cfg.Profile = profile + } else if cfg.Host == "" { + cfg.Profile = "DEFAULT" + } else if err := loadFromDatabricksCfg(cfg); err != nil { + return err + } + // Go SDK is lazy loaded because of Terraform semantics, + // so we're creating a dummy HTTP request as a placeholder + // for headers. + r := &http.Request{Header: http.Header{}} + err := cfg.Authenticate(r.WithContext(cmd.Context())) + if err != nil { + return err + } + vars := map[string]string{} + for _, a := range config.ConfigAttributes { + if a.IsZero(cfg) { + continue + } + envValue := a.GetString(cfg) + for _, envName := range a.EnvVars { + vars[envName] = envValue + } + } + raw, err := json.MarshalIndent(map[string]any{ + "env": vars, + }, "", " ") + if err != nil { + return err + } + cmd.OutOrStdout().Write(raw) + return nil + }, +} + +var host string +var profile string + +func init() { + authCmd.AddCommand(envCmd) + envCmd.Flags().StringVar(&host, "host", host, "Hostname to get auth env for") + envCmd.Flags().StringVar(&profile, "profile", profile, "Profile to get auth env for") +} diff --git a/cmd/auth/login.go b/cmd/auth/login.go new file mode 100644 index 00000000..39ea69ca --- /dev/null +++ b/cmd/auth/login.go @@ -0,0 +1,31 @@ +package auth + +import ( + "context" + "time" + + "github.com/databricks/bricks/libs/auth" + "github.com/spf13/cobra" +) + +var loginTimeout time.Duration + +var loginCmd = &cobra.Command{ + Use: "login [HOST]", + Short: "Authenticate this machine", + RunE: func(cmd *cobra.Command, args []string) error { + if perisistentAuth.Host == "" && len(args) == 1 { + perisistentAuth.Host = args[0] + } + defer perisistentAuth.Close() + ctx, cancel := 
context.WithTimeout(cmd.Context(), loginTimeout) + defer cancel() + return perisistentAuth.Challenge(ctx) + }, +} + +func init() { + authCmd.AddCommand(loginCmd) + loginCmd.Flags().DurationVar(&loginTimeout, "timeout", auth.DefaultTimeout, + "Timeout for completing login challenge in the browser") +} diff --git a/cmd/auth/profiles.go b/cmd/auth/profiles.go new file mode 100644 index 00000000..92b515f1 --- /dev/null +++ b/cmd/auth/profiles.go @@ -0,0 +1,131 @@ +package auth + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/config" + "github.com/spf13/cobra" + "gopkg.in/ini.v1" +) + +func getDatabricksCfg() (*ini.File, error) { + configFile := os.Getenv("DATABRICKS_CONFIG_FILE") + if configFile == "" { + configFile = "~/.databrickscfg" + } + if strings.HasPrefix(configFile, "~") { + homedir, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("cannot find homedir: %w", err) + } + configFile = filepath.Join(homedir, configFile[1:]) + } + return ini.Load(configFile) +} + +type profileMetadata struct { + Name string `json:"name"` + Host string `json:"host,omitempty"` + AccountID string `json:"account_id,omitempty"` + Cloud string `json:"cloud"` + AuthType string `json:"auth_type"` + Valid bool `json:"valid"` +} + +func (c *profileMetadata) IsEmpty() bool { + return c.Host == "" && c.AccountID == "" +} + +func (c *profileMetadata) Load(ctx context.Context) { + // TODO: disable config loaders other than configfile + cfg := &config.Config{Profile: c.Name} + _ = cfg.EnsureResolved() + if cfg.IsAws() { + c.Cloud = "aws" + } else if cfg.IsAzure() { + c.Cloud = "azure" + } else if cfg.IsGcp() { + c.Cloud = "gcp" + } + if cfg.IsAccountClient() { + a, err := databricks.NewAccountClient((*databricks.Config)(cfg)) + if err != nil { + return + } + _, err = a.Workspaces.List(ctx) + c.AuthType = cfg.AuthType + if err != nil { 
+ return + } + c.Valid = true + } else { + w, err := databricks.NewWorkspaceClient((*databricks.Config)(cfg)) + if err != nil { + return + } + _, err = w.Tokens.ListAll(ctx) + c.AuthType = cfg.AuthType + if err != nil { + return + } + c.Valid = true + } + // set host again, this time normalized + c.Host = cfg.Host +} + +var profilesCmd = &cobra.Command{ + Use: "profiles", + Short: "Lists profiles from ~/.databrickscfg", + RunE: func(cmd *cobra.Command, args []string) error { + iniFile, err := getDatabricksCfg() + if os.IsNotExist(err) { + // early return for non-configured machines + return errors.New("~/.databrickcfg not found on current host") + } + if err != nil { + return fmt.Errorf("cannot parse config file: %w", err) + } + var profiles []*profileMetadata + var wg sync.WaitGroup + for _, v := range iniFile.Sections() { + hash := v.KeysHash() + profile := &profileMetadata{ + Name: v.Name(), + Host: hash["host"], + AccountID: hash["account_id"], + } + if profile.IsEmpty() { + continue + } + wg.Add(1) + go func() { + // load more information about profile + profile.Load(cmd.Context()) + wg.Done() + }() + profiles = append(profiles, profile) + } + wg.Wait() + raw, err := json.MarshalIndent(map[string]any{ + "profiles": profiles, + }, "", " ") + if err != nil { + return err + } + cmd.OutOrStdout().Write(raw) + return nil + }, +} + +func init() { + authCmd.AddCommand(profilesCmd) +} diff --git a/cmd/auth/token.go b/cmd/auth/token.go new file mode 100644 index 00000000..e1f9f405 --- /dev/null +++ b/cmd/auth/token.go @@ -0,0 +1,41 @@ +package auth + +import ( + "context" + "encoding/json" + "time" + + "github.com/databricks/bricks/libs/auth" + "github.com/spf13/cobra" +) + +var tokenTimeout time.Duration + +var tokenCmd = &cobra.Command{ + Use: "token [HOST]", + Short: "Get authentication token", + RunE: func(cmd *cobra.Command, args []string) error { + if perisistentAuth.Host == "" && len(args) == 1 { + perisistentAuth.Host = args[0] + } + defer 
perisistentAuth.Close() + ctx, cancel := context.WithTimeout(cmd.Context(), tokenTimeout) + defer cancel() + t, err := perisistentAuth.Load(ctx) + if err != nil { + return err + } + raw, err := json.MarshalIndent(t, "", " ") + if err != nil { + return err + } + cmd.OutOrStdout().Write(raw) + return nil + }, +} + +func init() { + authCmd.AddCommand(tokenCmd) + tokenCmd.Flags().DurationVar(&tokenTimeout, "timeout", auth.DefaultTimeout, + "Timeout for acquiring a token.") +} diff --git a/go.mod b/go.mod index 8bb503e2..d3d17e82 100644 --- a/go.mod +++ b/go.mod @@ -48,9 +48,9 @@ require ( github.com/spf13/pflag v1.0.5 go.opencensus.io v0.24.0 // indirect golang.org/x/net v0.1.0 // indirect - golang.org/x/oauth2 v0.0.0-20221014153046-6fdb5e3db783 // indirect + golang.org/x/oauth2 v0.0.0-20221014153046-6fdb5e3db783 golang.org/x/sys v0.1.0 // indirect - golang.org/x/text v0.5.0 // indirect + golang.org/x/text v0.5.0 golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac // indirect google.golang.org/api v0.105.0 // indirect google.golang.org/appengine v1.6.7 // indirect diff --git a/libs/auth/cache/cache.go b/libs/auth/cache/cache.go new file mode 100644 index 00000000..5511c192 --- /dev/null +++ b/libs/auth/cache/cache.go @@ -0,0 +1,106 @@ +package cache + +import ( + "encoding/json" + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + + "golang.org/x/oauth2" +) + +const ( + // where the token cache is stored + tokenCacheFile = ".databricks/token-cache.json" + + // only the owner of the file has full execute, read, and write access + ownerExecReadWrite = 0o700 + + // only the owner of the file has full read and write access + ownerReadWrite = 0o600 + + // format versioning leaves some room for format improvement + tokenCacheVersion = 1 +) + +var ErrNotConfigured = errors.New("databricks OAuth is not configured for this host") + +// this implementation requires the calling code to do a machine-wide lock, +// otherwise the file might get corrupt. 
+type TokenCache struct { + Version int `json:"version"` + Tokens map[string]*oauth2.Token `json:"tokens"` + + fileLocation string +} + +func (c *TokenCache) Store(key string, t *oauth2.Token) error { + err := c.load() + if errors.Is(err, fs.ErrNotExist) { + dir := filepath.Dir(c.fileLocation) + err = os.MkdirAll(dir, ownerExecReadWrite) + if err != nil { + return fmt.Errorf("mkdir: %w", err) + } + } else if err != nil { + return fmt.Errorf("load: %w", err) + } + c.Version = tokenCacheVersion + if c.Tokens == nil { + c.Tokens = map[string]*oauth2.Token{} + } + c.Tokens[key] = t + raw, err := json.MarshalIndent(c, "", " ") + if err != nil { + return fmt.Errorf("marshal: %w", err) + } + return os.WriteFile(c.fileLocation, raw, ownerReadWrite) +} + +func (c *TokenCache) Lookup(key string) (*oauth2.Token, error) { + err := c.load() + if errors.Is(err, fs.ErrNotExist) { + return nil, ErrNotConfigured + } else if err != nil { + return nil, fmt.Errorf("load: %w", err) + } + t, ok := c.Tokens[key] + if !ok { + return nil, ErrNotConfigured + } + return t, nil +} + +func (c *TokenCache) location() (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("home: %w", err) + } + return filepath.Join(home, tokenCacheFile), nil +} + +func (c *TokenCache) load() error { + loc, err := c.location() + if err != nil { + return err + } + c.fileLocation = loc + raw, err := os.ReadFile(loc) + if err != nil { + return fmt.Errorf("read: %w", err) + } + err = json.Unmarshal(raw, c) + if err != nil { + return fmt.Errorf("parse: %w", err) + } + if c.Version != tokenCacheVersion { + // in the later iterations we could do state upgraders, + // so that we transform token cache from v1 to v2 without + // losing the tokens and asking the user to re-authenticate. 
+ return fmt.Errorf("needs version %d, got version %d", + tokenCacheVersion, c.Version) + } + return nil +} diff --git a/libs/auth/cache/cache_test.go b/libs/auth/cache/cache_test.go new file mode 100644 index 00000000..6529882c --- /dev/null +++ b/libs/auth/cache/cache_test.go @@ -0,0 +1,105 @@ +package cache + +import ( + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +var homeEnvVar = "HOME" + +func init() { + if runtime.GOOS == "windows" { + homeEnvVar = "USERPROFILE" + } +} + +func setup(t *testing.T) string { + tempHomeDir := t.TempDir() + t.Setenv(homeEnvVar, tempHomeDir) + return tempHomeDir +} + +func TestStoreAndLookup(t *testing.T) { + setup(t) + c := &TokenCache{} + err := c.Store("x", &oauth2.Token{ + AccessToken: "abc", + }) + require.NoError(t, err) + + err = c.Store("y", &oauth2.Token{ + AccessToken: "bcd", + }) + require.NoError(t, err) + + l := &TokenCache{} + tok, err := l.Lookup("x") + require.NoError(t, err) + assert.Equal(t, "abc", tok.AccessToken) + assert.Equal(t, 2, len(l.Tokens)) + + _, err = l.Lookup("z") + assert.Equal(t, ErrNotConfigured, err) +} + +func TestNoCacheFileReturnsErrNotConfigured(t *testing.T) { + setup(t) + l := &TokenCache{} + _, err := l.Lookup("x") + assert.Equal(t, ErrNotConfigured, err) +} + +func TestLoadCorruptFile(t *testing.T) { + home := setup(t) + f := filepath.Join(home, tokenCacheFile) + err := os.MkdirAll(filepath.Dir(f), ownerExecReadWrite) + require.NoError(t, err) + err = os.WriteFile(f, []byte("abc"), ownerExecReadWrite) + require.NoError(t, err) + + l := &TokenCache{} + _, err = l.Lookup("x") + assert.EqualError(t, err, "load: parse: invalid character 'a' looking for beginning of value") +} + +func TestLoadWrongVersion(t *testing.T) { + home := setup(t) + f := filepath.Join(home, tokenCacheFile) + err := os.MkdirAll(filepath.Dir(f), ownerExecReadWrite) + require.NoError(t, err) + err = 
os.WriteFile(f, []byte(`{"version": 823, "things": []}`), ownerExecReadWrite) + require.NoError(t, err) + + l := &TokenCache{} + _, err = l.Lookup("x") + assert.EqualError(t, err, "load: needs version 1, got version 823") +} + +func TestDevNull(t *testing.T) { + t.Setenv(homeEnvVar, "/dev/null") + l := &TokenCache{} + _, err := l.Lookup("x") + // macOS/Linux: load: read: open /dev/null/.databricks/token-cache.json: + // windows: databricks OAuth is not configured for this host + assert.Error(t, err) +} + +func TestStoreOnDev(t *testing.T) { + if runtime.GOOS == "windows" { + t.SkipNow() + } + t.Setenv(homeEnvVar, "/dev") + c := &TokenCache{} + err := c.Store("x", &oauth2.Token{ + AccessToken: "abc", + }) + // Linux: permission denied + // macOS: read-only file system + assert.Error(t, err) +} diff --git a/libs/auth/callback.go b/libs/auth/callback.go new file mode 100644 index 00000000..5a240069 --- /dev/null +++ b/libs/auth/callback.go @@ -0,0 +1,102 @@ +package auth + +import ( + "context" + _ "embed" + "fmt" + "html/template" + "net" + "net/http" + "strings" + + "golang.org/x/text/cases" + "golang.org/x/text/language" +) + +//go:embed page.tmpl +var pageTmpl string + +type oauthResult struct { + Error string + ErrorDescription string + Host string + State string + Code string +} + +type callbackServer struct { + ln net.Listener + srv http.Server + ctx context.Context + a *PersistentAuth + renderErrCh chan error + feedbackCh chan oauthResult + tmpl *template.Template +} + +func newCallback(ctx context.Context, a *PersistentAuth) (*callbackServer, error) { + tmpl, err := template.New("page").Funcs(template.FuncMap{ + "title": func(in string) string { + title := cases.Title(language.English) + return title.String(strings.ReplaceAll(in, "_", " ")) + }, + }).Parse(pageTmpl) + if err != nil { + return nil, err + } + cb := &callbackServer{ + feedbackCh: make(chan oauthResult), + renderErrCh: make(chan error), + tmpl: tmpl, + ctx: ctx, + ln: a.ln, + a: a, + } + 
cb.srv.Handler = cb + go cb.srv.Serve(cb.ln) + return cb, nil +} + +func (cb *callbackServer) Close() error { + return cb.srv.Close() +} + +// ServeHTTP renders page.html template +func (cb *callbackServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + res := oauthResult{ + Error: r.FormValue("error"), + ErrorDescription: r.FormValue("error_description"), + Code: r.FormValue("code"), + State: r.FormValue("state"), + Host: cb.a.Host, + } + if res.Error != "" { + w.WriteHeader(http.StatusBadRequest) + } else { + w.WriteHeader(http.StatusOK) + } + err := cb.tmpl.Execute(w, res) + if err != nil { + cb.renderErrCh <- err + } + cb.feedbackCh <- res +} + +// Handler opens up a browser waits for redirect to come back from the identity provider +func (cb *callbackServer) Handler(authCodeURL string) (string, string, error) { + err := cb.a.browser(authCodeURL) + if err != nil { + fmt.Printf("Please open %s in the browser to continue authentication", authCodeURL) + } + select { + case <-cb.ctx.Done(): + return "", "", cb.ctx.Err() + case renderErr := <-cb.renderErrCh: + return "", "", renderErr + case res := <-cb.feedbackCh: + if res.Error != "" { + return "", "", fmt.Errorf("%s: %s", res.Error, res.ErrorDescription) + } + return res.Code, res.State, nil + } +} diff --git a/libs/auth/oauth.go b/libs/auth/oauth.go new file mode 100644 index 00000000..b52b59e1 --- /dev/null +++ b/libs/auth/oauth.go @@ -0,0 +1,268 @@ +package auth + +import ( + "context" + "crypto/sha256" + _ "embed" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "math/rand" + "net" + "net/http" + "strings" + "time" + + "github.com/databricks/bricks/libs/auth/cache" + "github.com/databricks/databricks-sdk-go/retries" + "github.com/pkg/browser" + "golang.org/x/oauth2" + "golang.org/x/oauth2/authhandler" +) + +const ( + // these values are predefined by Databricks as a public client + // and is specific to this application only. 
Using these values + // for other applications is not allowed. + appClientID = "databricks-cli" + appRedirectAddr = "localhost:8020" + + // maximum amount of time to acquire listener on appRedirectAddr + DefaultTimeout = 15 * time.Second +) + +var ( // Databricks SDK API: `databricks OAuth is not` will be checked for presence + ErrOAuthNotSupported = errors.New("databricks OAuth is not supported for this host") + ErrNotConfigured = errors.New("databricks OAuth is not configured for this host") + ErrFetchCredentials = errors.New("cannot fetch credentials") +) + +type PersistentAuth struct { + Host string + AccountID string + + http httpGet + cache tokenCache + ln net.Listener + browser func(string) error +} + +type httpGet interface { + Get(string) (*http.Response, error) +} + +type tokenCache interface { + Store(key string, t *oauth2.Token) error + Lookup(key string) (*oauth2.Token, error) +} + +func (a *PersistentAuth) Load(ctx context.Context) (*oauth2.Token, error) { + err := a.init(ctx) + if err != nil { + return nil, fmt.Errorf("init: %w", err) + } + // lookup token identified by host (and possibly the account id) + key := a.key() + t, err := a.cache.Lookup(key) + if err != nil { + return nil, fmt.Errorf("cache: %w", err) + } + // early return for valid tokens + if t.Valid() { + // do not print refresh token to end-user + t.RefreshToken = "" + return t, nil + } + // OAuth2 config is invoked only for expired tokens to speed up + // the happy path in the token retrieval + cfg, err := a.oauth2Config() + if err != nil { + return nil, err + } + // eagerly refresh token + refreshed, err := cfg.TokenSource(ctx, t).Token() + if err != nil { + return nil, fmt.Errorf("token refresh: %w", err) + } + err = a.cache.Store(key, refreshed) + if err != nil { + return nil, fmt.Errorf("cache refresh: %w", err) + } + // do not print refresh token to end-user + refreshed.RefreshToken = "" + return refreshed, nil +} + +func (a *PersistentAuth) Challenge(ctx context.Context) error { 
+ err := a.init(ctx) + if err != nil { + return fmt.Errorf("init: %w", err) + } + cfg, err := a.oauth2Config() + if err != nil { + return err + } + cb, err := newCallback(ctx, a) + if err != nil { + return fmt.Errorf("callback server: %w", err) + } + defer cb.Close() + state, pkce := a.stateAndPKCE() + ts := authhandler.TokenSourceWithPKCE(ctx, cfg, state, cb.Handler, pkce) + t, err := ts.Token() + if err != nil { + return fmt.Errorf("authorize: %w", err) + } + // cache token identified by host (and possibly the account id) + err = a.cache.Store(a.key(), t) + if err != nil { + return fmt.Errorf("store: %w", err) + } + return nil +} + +func (a *PersistentAuth) init(ctx context.Context) error { + if a.Host == "" && a.AccountID == "" { + return ErrFetchCredentials + } + if a.http == nil { + a.http = http.DefaultClient + } + if a.cache == nil { + a.cache = &cache.TokenCache{} + } + if a.browser == nil { + a.browser = browser.OpenURL + } + // try acquire listener, which we also use as a machine-local + // exclusive lock to prevent token cache corruption in the scope + // of developer machine, where this command runs. 
+ listener, err := retries.Poll(ctx, DefaultTimeout, + func() (*net.Listener, *retries.Err) { + var lc net.ListenConfig + l, err := lc.Listen(ctx, "tcp", appRedirectAddr) + if err != nil { + return nil, retries.Continue(err) + } + return &l, nil + }) + if err != nil { + return fmt.Errorf("listener: %w", err) + } + a.ln = *listener + return nil +} + +func (a *PersistentAuth) Close() error { + if a.ln == nil { + return nil + } + return a.ln.Close() +} + +func (a *PersistentAuth) oidcEndpoints() (*oauthAuthorizationServer, error) { + prefix := a.key() + if a.AccountID != "" { + return &oauthAuthorizationServer{ + AuthorizationEndpoint: fmt.Sprintf("%s/v1/authorize", prefix), + TokenEndpoint: fmt.Sprintf("%s/v1/token", prefix), + }, nil + } + oidc := fmt.Sprintf("%s/oidc/.well-known/oauth-authorization-server", prefix) + oidcResponse, err := a.http.Get(oidc) + if err != nil { + return nil, fmt.Errorf("fetch .well-known: %w", err) + } + if oidcResponse.StatusCode != 200 { + return nil, ErrOAuthNotSupported + } + if oidcResponse.Body == nil { + return nil, fmt.Errorf("fetch .well-known: empty body") + } + defer oidcResponse.Body.Close() + raw, err := io.ReadAll(oidcResponse.Body) + if err != nil { + return nil, fmt.Errorf("read .well-known: %w", err) + } + var oauthEndpoints oauthAuthorizationServer + err = json.Unmarshal(raw, &oauthEndpoints) + if err != nil { + return nil, fmt.Errorf("parse .well-known: %w", err) + } + return &oauthEndpoints, nil +} + +func (a *PersistentAuth) oauth2Config() (*oauth2.Config, error) { + // in this iteration of CLI, we're using all scopes by default, + // because tools like CLI and Terraform do use all apis. This + // decision may be reconsidered later, once we have a proper + // taxonomy of all scopes ready and implemented. 
+ scopes := []string{ + "offline_access", + "unity-catalog", + "accounts", + "clusters", + "mlflow", + "scim", + "sql", + } + if a.AccountID != "" { + scopes = []string{ + "offline_access", + "accounts", + } + } + endpoints, err := a.oidcEndpoints() + if err != nil { + return nil, fmt.Errorf("oidc: %w", err) + } + return &oauth2.Config{ + ClientID: appClientID, + Endpoint: oauth2.Endpoint{ + AuthURL: endpoints.AuthorizationEndpoint, + TokenURL: endpoints.TokenEndpoint, + AuthStyle: oauth2.AuthStyleInParams, + }, + RedirectURL: fmt.Sprintf("http://%s", appRedirectAddr), + Scopes: scopes, + }, nil +} + +// key is currently used for two purposes: OIDC URL prefix and token cache key. +// once we decide to start storing scopes in the token cache, we should change +// this approach. +func (a *PersistentAuth) key() string { + a.Host = strings.TrimSuffix(a.Host, "/") + if !strings.HasPrefix(a.Host, "http") { + a.Host = fmt.Sprintf("https://%s", a.Host) + } + if a.AccountID != "" { + return fmt.Sprintf("%s/oidc/accounts/%s", a.Host, a.AccountID) + } + return a.Host +} + +func (a *PersistentAuth) stateAndPKCE() (string, *authhandler.PKCEParams) { + verifier := a.randomString(64) + verifierSha256 := sha256.Sum256([]byte(verifier)) + challenge := base64.RawURLEncoding.EncodeToString(verifierSha256[:]) + return a.randomString(16), &authhandler.PKCEParams{ + Challenge: challenge, + ChallengeMethod: "S256", + Verifier: verifier, + } +} + +func (a *PersistentAuth) randomString(size int) string { + rand.Seed(time.Now().UnixNano()) + raw := make([]byte, size) + _, _ = rand.Read(raw) + return base64.RawURLEncoding.EncodeToString(raw) +} + +type oauthAuthorizationServer struct { + AuthorizationEndpoint string `json:"authorization_endpoint"` // ../v1/authorize + TokenEndpoint string `json:"token_endpoint"` // ../v1/token +} diff --git a/libs/auth/oauth_test.go b/libs/auth/oauth_test.go new file mode 100644 index 00000000..9b5aa9ac --- /dev/null +++ b/libs/auth/oauth_test.go @@ -0,0 
+1,235 @@ +package auth + +import ( + "context" + "crypto/tls" + _ "embed" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "testing" + "time" + + "github.com/databricks/databricks-sdk-go/client" + "github.com/databricks/databricks-sdk-go/qa" + "github.com/stretchr/testify/assert" + "golang.org/x/oauth2" +) + +func TestOidcEndpointsForAccounts(t *testing.T) { + p := &PersistentAuth{ + Host: "abc", + AccountID: "xyz", + } + defer p.Close() + s, err := p.oidcEndpoints() + assert.NoError(t, err) + assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/authorize", s.AuthorizationEndpoint) + assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/token", s.TokenEndpoint) +} + +type mockGet func(url string) (*http.Response, error) + +func (m mockGet) Get(url string) (*http.Response, error) { + return m(url) +} + +func TestOidcForWorkspace(t *testing.T) { + p := &PersistentAuth{ + Host: "abc", + http: mockGet(func(url string) (*http.Response, error) { + assert.Equal(t, "https://abc/oidc/.well-known/oauth-authorization-server", url) + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(`{ + "authorization_endpoint": "a", + "token_endpoint": "b" + }`)), + }, nil + }), + } + defer p.Close() + endpoints, err := p.oidcEndpoints() + assert.NoError(t, err) + assert.Equal(t, "a", endpoints.AuthorizationEndpoint) + assert.Equal(t, "b", endpoints.TokenEndpoint) +} + +type tokenCacheMock struct { + store func(key string, t *oauth2.Token) error + lookup func(key string) (*oauth2.Token, error) +} + +func (m *tokenCacheMock) Store(key string, t *oauth2.Token) error { + if m.store == nil { + panic("no store mock") + } + return m.store(key, t) +} + +func (m *tokenCacheMock) Lookup(key string) (*oauth2.Token, error) { + if m.lookup == nil { + panic("no lookup mock") + } + return m.lookup(key) +} + +func TestLoad(t *testing.T) { + p := &PersistentAuth{ + Host: "abc", + AccountID: "xyz", + cache: &tokenCacheMock{ + lookup: func(key string) (*oauth2.Token, error) { + 
assert.Equal(t, "https://abc/oidc/accounts/xyz", key) + return &oauth2.Token{ + AccessToken: "bcd", + Expiry: time.Now().Add(1 * time.Minute), + }, nil + }, + }, + } + defer p.Close() + tok, err := p.Load(context.Background()) + assert.NoError(t, err) + assert.Equal(t, "bcd", tok.AccessToken) + assert.Equal(t, "", tok.RefreshToken) +} + +func useInsecureOAuthHttpClientForTests(ctx context.Context) context.Context { + return context.WithValue(ctx, oauth2.HTTPClient, &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + }) +} + +func TestLoadRefresh(t *testing.T) { + qa.HTTPFixtures{ + { + Method: "POST", + Resource: "/oidc/accounts/xyz/v1/token", + Response: `access_token=refreshed&refresh_token=def`, + }, + }.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) { + ctx = useInsecureOAuthHttpClientForTests(ctx) + expectedKey := fmt.Sprintf("%s/oidc/accounts/xyz", c.Config.Host) + p := &PersistentAuth{ + Host: c.Config.Host, + AccountID: "xyz", + cache: &tokenCacheMock{ + lookup: func(key string) (*oauth2.Token, error) { + assert.Equal(t, expectedKey, key) + return &oauth2.Token{ + AccessToken: "expired", + RefreshToken: "cde", + Expiry: time.Now().Add(-1 * time.Minute), + }, nil + }, + store: func(key string, tok *oauth2.Token) error { + assert.Equal(t, expectedKey, key) + assert.Equal(t, "def", tok.RefreshToken) + return nil + }, + }, + } + defer p.Close() + tok, err := p.Load(ctx) + assert.NoError(t, err) + assert.Equal(t, "refreshed", tok.AccessToken) + assert.Equal(t, "", tok.RefreshToken) + }) +} + +func TestChallenge(t *testing.T) { + qa.HTTPFixtures{ + { + Method: "POST", + Resource: "/oidc/accounts/xyz/v1/token", + Response: `access_token=__THAT__&refresh_token=__SOMETHING__`, + }, + }.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) { + ctx = useInsecureOAuthHttpClientForTests(ctx) + expectedKey := fmt.Sprintf("%s/oidc/accounts/xyz", c.Config.Host) + + 
browserOpened := make(chan string) + p := &PersistentAuth{ + Host: c.Config.Host, + AccountID: "xyz", + browser: func(redirect string) error { + u, err := url.ParseRequestURI(redirect) + if err != nil { + return err + } + assert.Equal(t, "/oidc/accounts/xyz/v1/authorize", u.Path) + // for now we're ignoring asserting the fields of the redirect + query := u.Query() + browserOpened <- query.Get("state") + return nil + }, + cache: &tokenCacheMock{ + store: func(key string, tok *oauth2.Token) error { + assert.Equal(t, expectedKey, key) + assert.Equal(t, "__SOMETHING__", tok.RefreshToken) + return nil + }, + }, + } + defer p.Close() + + errc := make(chan error) + go func() { + errc <- p.Challenge(ctx) + }() + + state := <-browserOpened + resp, err := http.Get(fmt.Sprintf("http://%s?code=__THIS__&state=%s", appRedirectAddr, state)) + assert.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + + err = <-errc + assert.NoError(t, err) + }) +} + +func TestChallengeFailed(t *testing.T) { + qa.HTTPFixtures{}.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) { + ctx = useInsecureOAuthHttpClientForTests(ctx) + + browserOpened := make(chan string) + p := &PersistentAuth{ + Host: c.Config.Host, + AccountID: "xyz", + browser: func(redirect string) error { + u, err := url.ParseRequestURI(redirect) + if err != nil { + return err + } + assert.Equal(t, "/oidc/accounts/xyz/v1/authorize", u.Path) + // for now we're ignoring asserting the fields of the redirect + query := u.Query() + browserOpened <- query.Get("state") + return nil + }, + } + defer p.Close() + + errc := make(chan error) + go func() { + errc <- p.Challenge(ctx) + }() + + <-browserOpened + resp, err := http.Get(fmt.Sprintf( + "http://%s?error=access_denied&error_description=Policy%%20evaluation%%20failed%%20for%%20this%%20request", + appRedirectAddr)) + assert.NoError(t, err) + assert.Equal(t, 400, resp.StatusCode) + + err = <-errc + assert.EqualError(t, err, "authorize: access_denied: Policy 
evaluation failed for this request") + }) +} diff --git a/libs/auth/page.tmpl b/libs/auth/page.tmpl new file mode 100644 index 00000000..00dfd64e --- /dev/null +++ b/libs/auth/page.tmpl @@ -0,0 +1,102 @@ + + + + + {{if .Error }}{{ .Error | title }}{{ else }}Success{{end}} + + + + + + + +
+
+ + +
{{ .Error | title }}
+
{{ .ErrorDescription }}
+ +
Authenticated
+
Go to {{.Host}}
+ +
+ You can close this tab. Or go to documentation +
+
+
+ + \ No newline at end of file diff --git a/main.go b/main.go index 472f49ba..1a6958fb 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( _ "github.com/databricks/bricks/cmd/api" + _ "github.com/databricks/bricks/cmd/auth" _ "github.com/databricks/bricks/cmd/bundle" _ "github.com/databricks/bricks/cmd/bundle/debug" _ "github.com/databricks/bricks/cmd/bundle/debug/deploy"