mirror of https://github.com/databricks/cli.git
Added `bricks auth login` and `bricks auth token` (#158)
# Auth challenge (happy path) Simplified description of [PKCE](https://oauth.net/2/pkce/) implementation: ```mermaid sequenceDiagram autonumber actor User User ->> CLI: type `bricks auth login HOST` CLI ->>+ HOST: request OIDC endpoints HOST ->>- CLI: auth & token endpoints CLI ->> CLI: start embedded server to consume redirects (lock) CLI -->>+ Auth Endpoint: open browser with RND1 + SHA256(RND2) User ->>+ Auth Endpoint: Go through SSO Auth Endpoint ->>- CLI: AUTH CODE + 'RND1 (redirect) CLI ->>+ Token Endpoint: Exchange: AUTH CODE + RND2 Token Endpoint ->>- CLI: Access Token (JWT) + refresh + expiry CLI ->> Token cache: Save Access Token (JWT) + refresh + expiry CLI ->> User: success ``` # Token refresh (happy path) ```mermaid sequenceDiagram autonumber actor User User ->> CLI: type `bricks token HOST` CLI ->> CLI: acquire lock (same local addr as redirect server) CLI ->>+ Token cache: read token critical token not expired Token cache ->>- User: JWT (without refresh) option token is expired CLI ->>+ HOST: request OIDC endpoints HOST ->>- CLI: auth & token endpoints CLI ->>+ Token Endpoint: refresh token Token Endpoint ->>- CLI: JWT (refreshed) CLI ->> Token cache: save JWT (refreshed) CLI ->> User: JWT (refreshed) option no auth for host CLI -X User: no auth configured end ```
This commit is contained in:
parent
a59136f77f
commit
b87b4b0f40
|
@ -0,0 +1,5 @@
|
|||
{
|
||||
"recommendations": [
|
||||
"bierner.markdown-mermaid"
|
||||
]
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
# Auth challenge (happy path)
|
||||
|
||||
Simplified description of [PKCE](https://oauth.net/2/pkce/) implementation:
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
autonumber
|
||||
actor User
|
||||
|
||||
User ->> CLI: type `bricks auth login HOST`
|
||||
CLI ->>+ HOST: request OIDC endpoints
|
||||
HOST ->>- CLI: auth & token endpoints
|
||||
CLI ->> CLI: start embedded server to consume redirects (lock)
|
||||
CLI -->>+ Auth Endpoint: open browser with RND1 + SHA256(RND2)
|
||||
|
||||
User ->>+ Auth Endpoint: Go through SSO
|
||||
Auth Endpoint ->>- CLI: AUTH CODE + 'RND1 (redirect)
|
||||
|
||||
CLI ->>+ Token Endpoint: Exchange: AUTH CODE + RND2
|
||||
Token Endpoint ->>- CLI: Access Token (JWT) + refresh + expiry
|
||||
CLI ->> Token cache: Save Access Token (JWT) + refresh + expiry
|
||||
CLI ->> User: success
|
||||
```
|
||||
|
||||
# Token refresh (happy path)
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
autonumber
|
||||
actor User
|
||||
|
||||
User ->> CLI: type `bricks token HOST`
|
||||
|
||||
CLI ->> CLI: acquire lock (same local addr as redirect server)
|
||||
CLI ->>+ Token cache: read token
|
||||
|
||||
critical token not expired
|
||||
Token cache ->>- User: JWT (without refresh)
|
||||
|
||||
option token is expired
|
||||
CLI ->>+ HOST: request OIDC endpoints
|
||||
HOST ->>- CLI: auth & token endpoints
|
||||
CLI ->>+ Token Endpoint: refresh token
|
||||
Token Endpoint ->>- CLI: JWT (refreshed)
|
||||
CLI ->> Token cache: save JWT (refreshed)
|
||||
CLI ->> User: JWT (refreshed)
|
||||
|
||||
option no auth for host
|
||||
CLI -X User: no auth configured
|
||||
end
|
||||
```
|
|
@ -0,0 +1,20 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"github.com/databricks/bricks/cmd/root"
|
||||
"github.com/databricks/bricks/libs/auth"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var authCmd = &cobra.Command{
|
||||
Use: "auth",
|
||||
Short: "Authentication related commands",
|
||||
}
|
||||
|
||||
var perisistentAuth auth.PersistentAuth
|
||||
|
||||
func init() {
|
||||
root.RootCmd.AddCommand(authCmd)
|
||||
authCmd.PersistentFlags().StringVar(&perisistentAuth.Host, "host", perisistentAuth.Host, "Databricks Host")
|
||||
authCmd.PersistentFlags().StringVar(&perisistentAuth.AccountID, "account-id", perisistentAuth.AccountID, "Databricks Account ID")
|
||||
}
|
|
@ -0,0 +1,143 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/databricks/databricks-sdk-go/config"
|
||||
"github.com/spf13/cobra"
|
||||
"gopkg.in/ini.v1"
|
||||
)
|
||||
|
||||
func canonicalHost(host string) (string, error) {
|
||||
parsedHost, err := url.Parse(host)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// If the host is empty, assume the scheme wasn't included.
|
||||
if parsedHost.Host == "" {
|
||||
return fmt.Sprintf("https://%s", host), nil
|
||||
}
|
||||
return fmt.Sprintf("https://%s", parsedHost.Host), nil
|
||||
}
|
||||
|
||||
var ErrNoMatchingProfiles = errors.New("no matching profiles found")
|
||||
|
||||
func resolveSection(cfg *config.Config, iniFile *ini.File) (*ini.Section, error) {
|
||||
var candidates []*ini.Section
|
||||
configuredHost, err := canonicalHost(cfg.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, section := range iniFile.Sections() {
|
||||
hash := section.KeysHash()
|
||||
host, ok := hash["host"]
|
||||
if !ok {
|
||||
// if host is not set
|
||||
continue
|
||||
}
|
||||
canonical, err := canonicalHost(host)
|
||||
if err != nil {
|
||||
// we're fine with other corrupt profiles
|
||||
continue
|
||||
}
|
||||
if canonical != configuredHost {
|
||||
continue
|
||||
}
|
||||
candidates = append(candidates, section)
|
||||
}
|
||||
if len(candidates) == 0 {
|
||||
return nil, ErrNoMatchingProfiles
|
||||
}
|
||||
// in the real situations, we don't expect this to happen often
|
||||
// (if not at all), hence we don't trim the list
|
||||
if len(candidates) > 1 {
|
||||
var profiles []string
|
||||
for _, v := range candidates {
|
||||
profiles = append(profiles, v.Name())
|
||||
}
|
||||
return nil, fmt.Errorf("%s match %s in %s",
|
||||
strings.Join(profiles, " and "), cfg.Host, cfg.ConfigFile)
|
||||
}
|
||||
return candidates[0], nil
|
||||
}
|
||||
|
||||
func loadFromDatabricksCfg(cfg *config.Config) error {
|
||||
iniFile, err := getDatabricksCfg()
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
// it's fine not to have ~/.databrickscfg
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
profile, err := resolveSection(cfg, iniFile)
|
||||
if err == ErrNoMatchingProfiles {
|
||||
// it's also fine for Azure CLI or Bricks CLI, which
|
||||
// are resolved by unified auth handling in the Go SDK.
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfg.Profile = profile.Name()
|
||||
return nil
|
||||
}
|
||||
|
||||
var envCmd = &cobra.Command{
|
||||
Use: "env",
|
||||
Short: "Get env",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cfg := &config.Config{
|
||||
Host: host,
|
||||
Profile: profile,
|
||||
}
|
||||
if profile != "" {
|
||||
cfg.Profile = profile
|
||||
} else if cfg.Host == "" {
|
||||
cfg.Profile = "DEFAULT"
|
||||
} else if err := loadFromDatabricksCfg(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
// Go SDK is lazy loaded because of Terraform semantics,
|
||||
// so we're creating a dummy HTTP request as a placeholder
|
||||
// for headers.
|
||||
r := &http.Request{Header: http.Header{}}
|
||||
err := cfg.Authenticate(r.WithContext(cmd.Context()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
vars := map[string]string{}
|
||||
for _, a := range config.ConfigAttributes {
|
||||
if a.IsZero(cfg) {
|
||||
continue
|
||||
}
|
||||
envValue := a.GetString(cfg)
|
||||
for _, envName := range a.EnvVars {
|
||||
vars[envName] = envValue
|
||||
}
|
||||
}
|
||||
raw, err := json.MarshalIndent(map[string]any{
|
||||
"env": vars,
|
||||
}, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.OutOrStdout().Write(raw)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
var host string
|
||||
var profile string
|
||||
|
||||
func init() {
|
||||
authCmd.AddCommand(envCmd)
|
||||
envCmd.Flags().StringVar(&host, "host", host, "Hostname to get auth env for")
|
||||
envCmd.Flags().StringVar(&profile, "profile", profile, "Profile to get auth env for")
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/databricks/bricks/libs/auth"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var loginTimeout time.Duration
|
||||
|
||||
var loginCmd = &cobra.Command{
|
||||
Use: "login [HOST]",
|
||||
Short: "Authenticate this machine",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if perisistentAuth.Host == "" && len(args) == 1 {
|
||||
perisistentAuth.Host = args[0]
|
||||
}
|
||||
defer perisistentAuth.Close()
|
||||
ctx, cancel := context.WithTimeout(cmd.Context(), loginTimeout)
|
||||
defer cancel()
|
||||
return perisistentAuth.Challenge(ctx)
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
authCmd.AddCommand(loginCmd)
|
||||
loginCmd.Flags().DurationVar(&loginTimeout, "timeout", auth.DefaultTimeout,
|
||||
"Timeout for completing login challenge in the browser")
|
||||
}
|
|
@ -0,0 +1,131 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/databricks/databricks-sdk-go"
|
||||
"github.com/databricks/databricks-sdk-go/config"
|
||||
"github.com/spf13/cobra"
|
||||
"gopkg.in/ini.v1"
|
||||
)
|
||||
|
||||
func getDatabricksCfg() (*ini.File, error) {
|
||||
configFile := os.Getenv("DATABRICKS_CONFIG_FILE")
|
||||
if configFile == "" {
|
||||
configFile = "~/.databrickscfg"
|
||||
}
|
||||
if strings.HasPrefix(configFile, "~") {
|
||||
homedir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot find homedir: %w", err)
|
||||
}
|
||||
configFile = filepath.Join(homedir, configFile[1:])
|
||||
}
|
||||
return ini.Load(configFile)
|
||||
}
|
||||
|
||||
type profileMetadata struct {
|
||||
Name string `json:"name"`
|
||||
Host string `json:"host,omitempty"`
|
||||
AccountID string `json:"account_id,omitempty"`
|
||||
Cloud string `json:"cloud"`
|
||||
AuthType string `json:"auth_type"`
|
||||
Valid bool `json:"valid"`
|
||||
}
|
||||
|
||||
func (c *profileMetadata) IsEmpty() bool {
|
||||
return c.Host == "" && c.AccountID == ""
|
||||
}
|
||||
|
||||
func (c *profileMetadata) Load(ctx context.Context) {
|
||||
// TODO: disable config loaders other than configfile
|
||||
cfg := &config.Config{Profile: c.Name}
|
||||
_ = cfg.EnsureResolved()
|
||||
if cfg.IsAws() {
|
||||
c.Cloud = "aws"
|
||||
} else if cfg.IsAzure() {
|
||||
c.Cloud = "azure"
|
||||
} else if cfg.IsGcp() {
|
||||
c.Cloud = "gcp"
|
||||
}
|
||||
if cfg.IsAccountClient() {
|
||||
a, err := databricks.NewAccountClient((*databricks.Config)(cfg))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = a.Workspaces.List(ctx)
|
||||
c.AuthType = cfg.AuthType
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.Valid = true
|
||||
} else {
|
||||
w, err := databricks.NewWorkspaceClient((*databricks.Config)(cfg))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = w.Tokens.ListAll(ctx)
|
||||
c.AuthType = cfg.AuthType
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.Valid = true
|
||||
}
|
||||
// set host again, this time normalized
|
||||
c.Host = cfg.Host
|
||||
}
|
||||
|
||||
var profilesCmd = &cobra.Command{
|
||||
Use: "profiles",
|
||||
Short: "Lists profiles from ~/.databrickscfg",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
iniFile, err := getDatabricksCfg()
|
||||
if os.IsNotExist(err) {
|
||||
// early return for non-configured machines
|
||||
return errors.New("~/.databrickcfg not found on current host")
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot parse config file: %w", err)
|
||||
}
|
||||
var profiles []*profileMetadata
|
||||
var wg sync.WaitGroup
|
||||
for _, v := range iniFile.Sections() {
|
||||
hash := v.KeysHash()
|
||||
profile := &profileMetadata{
|
||||
Name: v.Name(),
|
||||
Host: hash["host"],
|
||||
AccountID: hash["account_id"],
|
||||
}
|
||||
if profile.IsEmpty() {
|
||||
continue
|
||||
}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
// load more information about profile
|
||||
profile.Load(cmd.Context())
|
||||
wg.Done()
|
||||
}()
|
||||
profiles = append(profiles, profile)
|
||||
}
|
||||
wg.Wait()
|
||||
raw, err := json.MarshalIndent(map[string]any{
|
||||
"profiles": profiles,
|
||||
}, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.OutOrStdout().Write(raw)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
authCmd.AddCommand(profilesCmd)
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/databricks/bricks/libs/auth"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var tokenTimeout time.Duration
|
||||
|
||||
var tokenCmd = &cobra.Command{
|
||||
Use: "token [HOST]",
|
||||
Short: "Get authentication token",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if perisistentAuth.Host == "" && len(args) == 1 {
|
||||
perisistentAuth.Host = args[0]
|
||||
}
|
||||
defer perisistentAuth.Close()
|
||||
ctx, cancel := context.WithTimeout(cmd.Context(), tokenTimeout)
|
||||
defer cancel()
|
||||
t, err := perisistentAuth.Load(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
raw, err := json.MarshalIndent(t, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.OutOrStdout().Write(raw)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
authCmd.AddCommand(tokenCmd)
|
||||
tokenCmd.Flags().DurationVar(&tokenTimeout, "timeout", auth.DefaultTimeout,
|
||||
"Timeout for acquiring a token.")
|
||||
}
|
4
go.mod
4
go.mod
|
@ -48,9 +48,9 @@ require (
|
|||
github.com/spf13/pflag v1.0.5
|
||||
go.opencensus.io v0.24.0 // indirect
|
||||
golang.org/x/net v0.1.0 // indirect
|
||||
golang.org/x/oauth2 v0.0.0-20221014153046-6fdb5e3db783 // indirect
|
||||
golang.org/x/oauth2 v0.0.0-20221014153046-6fdb5e3db783
|
||||
golang.org/x/sys v0.1.0 // indirect
|
||||
golang.org/x/text v0.5.0 // indirect
|
||||
golang.org/x/text v0.5.0
|
||||
golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac // indirect
|
||||
google.golang.org/api v0.105.0 // indirect
|
||||
google.golang.org/appengine v1.6.7 // indirect
|
||||
|
|
|
@ -0,0 +1,106 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
const (
|
||||
// where the token cache is stored
|
||||
tokenCacheFile = ".databricks/token-cache.json"
|
||||
|
||||
// only the owner of the file has full execute, read, and write access
|
||||
ownerExecReadWrite = 0o700
|
||||
|
||||
// only the owner of the file has full read and write access
|
||||
ownerReadWrite = 0o600
|
||||
|
||||
// format versioning leaves some room for format improvement
|
||||
tokenCacheVersion = 1
|
||||
)
|
||||
|
||||
var ErrNotConfigured = errors.New("databricks OAuth is not configured for this host")
|
||||
|
||||
// this implementation requires the calling code to do a machine-wide lock,
|
||||
// otherwise the file might get corrupt.
|
||||
type TokenCache struct {
|
||||
Version int `json:"version"`
|
||||
Tokens map[string]*oauth2.Token `json:"tokens"`
|
||||
|
||||
fileLocation string
|
||||
}
|
||||
|
||||
func (c *TokenCache) Store(key string, t *oauth2.Token) error {
|
||||
err := c.load()
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
dir := filepath.Dir(c.fileLocation)
|
||||
err = os.MkdirAll(dir, ownerExecReadWrite)
|
||||
if err != nil {
|
||||
return fmt.Errorf("mkdir: %w", err)
|
||||
}
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("load: %w", err)
|
||||
}
|
||||
c.Version = tokenCacheVersion
|
||||
if c.Tokens == nil {
|
||||
c.Tokens = map[string]*oauth2.Token{}
|
||||
}
|
||||
c.Tokens[key] = t
|
||||
raw, err := json.MarshalIndent(c, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal: %w", err)
|
||||
}
|
||||
return os.WriteFile(c.fileLocation, raw, ownerReadWrite)
|
||||
}
|
||||
|
||||
func (c *TokenCache) Lookup(key string) (*oauth2.Token, error) {
|
||||
err := c.load()
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return nil, ErrNotConfigured
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("load: %w", err)
|
||||
}
|
||||
t, ok := c.Tokens[key]
|
||||
if !ok {
|
||||
return nil, ErrNotConfigured
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (c *TokenCache) location() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("home: %w", err)
|
||||
}
|
||||
return filepath.Join(home, tokenCacheFile), nil
|
||||
}
|
||||
|
||||
func (c *TokenCache) load() error {
|
||||
loc, err := c.location()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.fileLocation = loc
|
||||
raw, err := os.ReadFile(loc)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read: %w", err)
|
||||
}
|
||||
err = json.Unmarshal(raw, c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse: %w", err)
|
||||
}
|
||||
if c.Version != tokenCacheVersion {
|
||||
// in the later iterations we could do state upgraders,
|
||||
// so that we transform token cache from v1 to v2 without
|
||||
// losing the tokens and asking the user to re-authenticate.
|
||||
return fmt.Errorf("needs version %d, got version %d",
|
||||
tokenCacheVersion, c.Version)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,105 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
var homeEnvVar = "HOME"
|
||||
|
||||
func init() {
|
||||
if runtime.GOOS == "windows" {
|
||||
homeEnvVar = "USERPROFILE"
|
||||
}
|
||||
}
|
||||
|
||||
func setup(t *testing.T) string {
|
||||
tempHomeDir := t.TempDir()
|
||||
t.Setenv(homeEnvVar, tempHomeDir)
|
||||
return tempHomeDir
|
||||
}
|
||||
|
||||
func TestStoreAndLookup(t *testing.T) {
|
||||
setup(t)
|
||||
c := &TokenCache{}
|
||||
err := c.Store("x", &oauth2.Token{
|
||||
AccessToken: "abc",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = c.Store("y", &oauth2.Token{
|
||||
AccessToken: "bcd",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
l := &TokenCache{}
|
||||
tok, err := l.Lookup("x")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "abc", tok.AccessToken)
|
||||
assert.Equal(t, 2, len(l.Tokens))
|
||||
|
||||
_, err = l.Lookup("z")
|
||||
assert.Equal(t, ErrNotConfigured, err)
|
||||
}
|
||||
|
||||
func TestNoCacheFileReturnsErrNotConfigured(t *testing.T) {
|
||||
setup(t)
|
||||
l := &TokenCache{}
|
||||
_, err := l.Lookup("x")
|
||||
assert.Equal(t, ErrNotConfigured, err)
|
||||
}
|
||||
|
||||
func TestLoadCorruptFile(t *testing.T) {
|
||||
home := setup(t)
|
||||
f := filepath.Join(home, tokenCacheFile)
|
||||
err := os.MkdirAll(filepath.Dir(f), ownerExecReadWrite)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(f, []byte("abc"), ownerExecReadWrite)
|
||||
require.NoError(t, err)
|
||||
|
||||
l := &TokenCache{}
|
||||
_, err = l.Lookup("x")
|
||||
assert.EqualError(t, err, "load: parse: invalid character 'a' looking for beginning of value")
|
||||
}
|
||||
|
||||
func TestLoadWrongVersion(t *testing.T) {
|
||||
home := setup(t)
|
||||
f := filepath.Join(home, tokenCacheFile)
|
||||
err := os.MkdirAll(filepath.Dir(f), ownerExecReadWrite)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(f, []byte(`{"version": 823, "things": []}`), ownerExecReadWrite)
|
||||
require.NoError(t, err)
|
||||
|
||||
l := &TokenCache{}
|
||||
_, err = l.Lookup("x")
|
||||
assert.EqualError(t, err, "load: needs version 1, got version 823")
|
||||
}
|
||||
|
||||
func TestDevNull(t *testing.T) {
|
||||
t.Setenv(homeEnvVar, "/dev/null")
|
||||
l := &TokenCache{}
|
||||
_, err := l.Lookup("x")
|
||||
// macOS/Linux: load: read: open /dev/null/.databricks/token-cache.json:
|
||||
// windows: databricks OAuth is not configured for this host
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestStoreOnDev(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.SkipNow()
|
||||
}
|
||||
t.Setenv(homeEnvVar, "/dev")
|
||||
c := &TokenCache{}
|
||||
err := c.Store("x", &oauth2.Token{
|
||||
AccessToken: "abc",
|
||||
})
|
||||
// Linux: permission denied
|
||||
// macOS: read-only file system
|
||||
assert.Error(t, err)
|
||||
}
|
|
@ -0,0 +1,102 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/text/cases"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
//go:embed page.tmpl
|
||||
var pageTmpl string
|
||||
|
||||
type oauthResult struct {
|
||||
Error string
|
||||
ErrorDescription string
|
||||
Host string
|
||||
State string
|
||||
Code string
|
||||
}
|
||||
|
||||
type callbackServer struct {
|
||||
ln net.Listener
|
||||
srv http.Server
|
||||
ctx context.Context
|
||||
a *PersistentAuth
|
||||
renderErrCh chan error
|
||||
feedbackCh chan oauthResult
|
||||
tmpl *template.Template
|
||||
}
|
||||
|
||||
func newCallback(ctx context.Context, a *PersistentAuth) (*callbackServer, error) {
|
||||
tmpl, err := template.New("page").Funcs(template.FuncMap{
|
||||
"title": func(in string) string {
|
||||
title := cases.Title(language.English)
|
||||
return title.String(strings.ReplaceAll(in, "_", " "))
|
||||
},
|
||||
}).Parse(pageTmpl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cb := &callbackServer{
|
||||
feedbackCh: make(chan oauthResult),
|
||||
renderErrCh: make(chan error),
|
||||
tmpl: tmpl,
|
||||
ctx: ctx,
|
||||
ln: a.ln,
|
||||
a: a,
|
||||
}
|
||||
cb.srv.Handler = cb
|
||||
go cb.srv.Serve(cb.ln)
|
||||
return cb, nil
|
||||
}
|
||||
|
||||
func (cb *callbackServer) Close() error {
|
||||
return cb.srv.Close()
|
||||
}
|
||||
|
||||
// ServeHTTP renders page.html template
|
||||
func (cb *callbackServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
res := oauthResult{
|
||||
Error: r.FormValue("error"),
|
||||
ErrorDescription: r.FormValue("error_description"),
|
||||
Code: r.FormValue("code"),
|
||||
State: r.FormValue("state"),
|
||||
Host: cb.a.Host,
|
||||
}
|
||||
if res.Error != "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
err := cb.tmpl.Execute(w, res)
|
||||
if err != nil {
|
||||
cb.renderErrCh <- err
|
||||
}
|
||||
cb.feedbackCh <- res
|
||||
}
|
||||
|
||||
// Handler opens up a browser waits for redirect to come back from the identity provider
|
||||
func (cb *callbackServer) Handler(authCodeURL string) (string, string, error) {
|
||||
err := cb.a.browser(authCodeURL)
|
||||
if err != nil {
|
||||
fmt.Printf("Please open %s in the browser to continue authentication", authCodeURL)
|
||||
}
|
||||
select {
|
||||
case <-cb.ctx.Done():
|
||||
return "", "", cb.ctx.Err()
|
||||
case renderErr := <-cb.renderErrCh:
|
||||
return "", "", renderErr
|
||||
case res := <-cb.feedbackCh:
|
||||
if res.Error != "" {
|
||||
return "", "", fmt.Errorf("%s: %s", res.Error, res.ErrorDescription)
|
||||
}
|
||||
return res.Code, res.State, nil
|
||||
}
|
||||
}
|
|
@ -0,0 +1,268 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
_ "embed"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/databricks/bricks/libs/auth/cache"
|
||||
"github.com/databricks/databricks-sdk-go/retries"
|
||||
"github.com/pkg/browser"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/authhandler"
|
||||
)
|
||||
|
||||
const (
|
||||
// these values are predefined by Databricks as a public client
|
||||
// and is specific to this application only. Using these values
|
||||
// for other applications is not allowed.
|
||||
appClientID = "databricks-cli"
|
||||
appRedirectAddr = "localhost:8020"
|
||||
|
||||
// maximum amount of time to acquire listener on appRedirectAddr
|
||||
DefaultTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
var ( // Databricks SDK API: `databricks OAuth is not` will be checked for presence
|
||||
ErrOAuthNotSupported = errors.New("databricks OAuth is not supported for this host")
|
||||
ErrNotConfigured = errors.New("databricks OAuth is not configured for this host")
|
||||
ErrFetchCredentials = errors.New("cannot fetch credentials")
|
||||
)
|
||||
|
||||
type PersistentAuth struct {
|
||||
Host string
|
||||
AccountID string
|
||||
|
||||
http httpGet
|
||||
cache tokenCache
|
||||
ln net.Listener
|
||||
browser func(string) error
|
||||
}
|
||||
|
||||
type httpGet interface {
|
||||
Get(string) (*http.Response, error)
|
||||
}
|
||||
|
||||
type tokenCache interface {
|
||||
Store(key string, t *oauth2.Token) error
|
||||
Lookup(key string) (*oauth2.Token, error)
|
||||
}
|
||||
|
||||
func (a *PersistentAuth) Load(ctx context.Context) (*oauth2.Token, error) {
|
||||
err := a.init(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init: %w", err)
|
||||
}
|
||||
// lookup token identified by host (and possibly the account id)
|
||||
key := a.key()
|
||||
t, err := a.cache.Lookup(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cache: %w", err)
|
||||
}
|
||||
// early return for valid tokens
|
||||
if t.Valid() {
|
||||
// do not print refresh token to end-user
|
||||
t.RefreshToken = ""
|
||||
return t, nil
|
||||
}
|
||||
// OAuth2 config is invoked only for expired tokens to speed up
|
||||
// the happy path in the token retrieval
|
||||
cfg, err := a.oauth2Config()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// eagerly refresh token
|
||||
refreshed, err := cfg.TokenSource(ctx, t).Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token refresh: %w", err)
|
||||
}
|
||||
err = a.cache.Store(key, refreshed)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cache refresh: %w", err)
|
||||
}
|
||||
// do not print refresh token to end-user
|
||||
refreshed.RefreshToken = ""
|
||||
return refreshed, nil
|
||||
}
|
||||
|
||||
func (a *PersistentAuth) Challenge(ctx context.Context) error {
|
||||
err := a.init(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init: %w", err)
|
||||
}
|
||||
cfg, err := a.oauth2Config()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cb, err := newCallback(ctx, a)
|
||||
if err != nil {
|
||||
return fmt.Errorf("callback server: %w", err)
|
||||
}
|
||||
defer cb.Close()
|
||||
state, pkce := a.stateAndPKCE()
|
||||
ts := authhandler.TokenSourceWithPKCE(ctx, cfg, state, cb.Handler, pkce)
|
||||
t, err := ts.Token()
|
||||
if err != nil {
|
||||
return fmt.Errorf("authorize: %w", err)
|
||||
}
|
||||
// cache token identified by host (and possibly the account id)
|
||||
err = a.cache.Store(a.key(), t)
|
||||
if err != nil {
|
||||
return fmt.Errorf("store: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *PersistentAuth) init(ctx context.Context) error {
|
||||
if a.Host == "" && a.AccountID == "" {
|
||||
return ErrFetchCredentials
|
||||
}
|
||||
if a.http == nil {
|
||||
a.http = http.DefaultClient
|
||||
}
|
||||
if a.cache == nil {
|
||||
a.cache = &cache.TokenCache{}
|
||||
}
|
||||
if a.browser == nil {
|
||||
a.browser = browser.OpenURL
|
||||
}
|
||||
// try acquire listener, which we also use as a machine-local
|
||||
// exclusive lock to prevent token cache corruption in the scope
|
||||
// of developer machine, where this command runs.
|
||||
listener, err := retries.Poll(ctx, DefaultTimeout,
|
||||
func() (*net.Listener, *retries.Err) {
|
||||
var lc net.ListenConfig
|
||||
l, err := lc.Listen(ctx, "tcp", appRedirectAddr)
|
||||
if err != nil {
|
||||
return nil, retries.Continue(err)
|
||||
}
|
||||
return &l, nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("listener: %w", err)
|
||||
}
|
||||
a.ln = *listener
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *PersistentAuth) Close() error {
|
||||
if a.ln == nil {
|
||||
return nil
|
||||
}
|
||||
return a.ln.Close()
|
||||
}
|
||||
|
||||
func (a *PersistentAuth) oidcEndpoints() (*oauthAuthorizationServer, error) {
|
||||
prefix := a.key()
|
||||
if a.AccountID != "" {
|
||||
return &oauthAuthorizationServer{
|
||||
AuthorizationEndpoint: fmt.Sprintf("%s/v1/authorize", prefix),
|
||||
TokenEndpoint: fmt.Sprintf("%s/v1/token", prefix),
|
||||
}, nil
|
||||
}
|
||||
oidc := fmt.Sprintf("%s/oidc/.well-known/oauth-authorization-server", prefix)
|
||||
oidcResponse, err := a.http.Get(oidc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch .well-known: %w", err)
|
||||
}
|
||||
if oidcResponse.StatusCode != 200 {
|
||||
return nil, ErrOAuthNotSupported
|
||||
}
|
||||
if oidcResponse.Body == nil {
|
||||
return nil, fmt.Errorf("fetch .well-known: empty body")
|
||||
}
|
||||
defer oidcResponse.Body.Close()
|
||||
raw, err := io.ReadAll(oidcResponse.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read .well-known: %w", err)
|
||||
}
|
||||
var oauthEndpoints oauthAuthorizationServer
|
||||
err = json.Unmarshal(raw, &oauthEndpoints)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse .well-known: %w", err)
|
||||
}
|
||||
return &oauthEndpoints, nil
|
||||
}
|
||||
|
||||
func (a *PersistentAuth) oauth2Config() (*oauth2.Config, error) {
|
||||
// in this iteration of CLI, we're using all scopes by default,
|
||||
// because tools like CLI and Terraform do use all apis. This
|
||||
// decision may be reconsidered later, once we have a proper
|
||||
// taxonomy of all scopes ready and implemented.
|
||||
scopes := []string{
|
||||
"offline_access",
|
||||
"unity-catalog",
|
||||
"accounts",
|
||||
"clusters",
|
||||
"mlflow",
|
||||
"scim",
|
||||
"sql",
|
||||
}
|
||||
if a.AccountID != "" {
|
||||
scopes = []string{
|
||||
"offline_access",
|
||||
"accounts",
|
||||
}
|
||||
}
|
||||
endpoints, err := a.oidcEndpoints()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oidc: %w", err)
|
||||
}
|
||||
return &oauth2.Config{
|
||||
ClientID: appClientID,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: endpoints.AuthorizationEndpoint,
|
||||
TokenURL: endpoints.TokenEndpoint,
|
||||
AuthStyle: oauth2.AuthStyleInParams,
|
||||
},
|
||||
RedirectURL: fmt.Sprintf("http://%s", appRedirectAddr),
|
||||
Scopes: scopes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// key is currently used for two purposes: OIDC URL prefix and token cache key.
|
||||
// once we decide to start storing scopes in the token cache, we should change
|
||||
// this approach.
|
||||
func (a *PersistentAuth) key() string {
|
||||
a.Host = strings.TrimSuffix(a.Host, "/")
|
||||
if !strings.HasPrefix(a.Host, "http") {
|
||||
a.Host = fmt.Sprintf("https://%s", a.Host)
|
||||
}
|
||||
if a.AccountID != "" {
|
||||
return fmt.Sprintf("%s/oidc/accounts/%s", a.Host, a.AccountID)
|
||||
}
|
||||
return a.Host
|
||||
}
|
||||
|
||||
func (a *PersistentAuth) stateAndPKCE() (string, *authhandler.PKCEParams) {
|
||||
verifier := a.randomString(64)
|
||||
verifierSha256 := sha256.Sum256([]byte(verifier))
|
||||
challenge := base64.RawURLEncoding.EncodeToString(verifierSha256[:])
|
||||
return a.randomString(16), &authhandler.PKCEParams{
|
||||
Challenge: challenge,
|
||||
ChallengeMethod: "S256",
|
||||
Verifier: verifier,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *PersistentAuth) randomString(size int) string {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
raw := make([]byte, size)
|
||||
_, _ = rand.Read(raw)
|
||||
return base64.RawURLEncoding.EncodeToString(raw)
|
||||
}
|
||||
|
||||
type oauthAuthorizationServer struct {
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"` // ../v1/authorize
|
||||
TokenEndpoint string `json:"token_endpoint"` // ../v1/token
|
||||
}
|
|
@ -0,0 +1,235 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/databricks/databricks-sdk-go/client"
|
||||
"github.com/databricks/databricks-sdk-go/qa"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func TestOidcEndpointsForAccounts(t *testing.T) {
|
||||
p := &PersistentAuth{
|
||||
Host: "abc",
|
||||
AccountID: "xyz",
|
||||
}
|
||||
defer p.Close()
|
||||
s, err := p.oidcEndpoints()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/authorize", s.AuthorizationEndpoint)
|
||||
assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/token", s.TokenEndpoint)
|
||||
}
|
||||
|
||||
type mockGet func(url string) (*http.Response, error)
|
||||
|
||||
func (m mockGet) Get(url string) (*http.Response, error) {
|
||||
return m(url)
|
||||
}
|
||||
|
||||
func TestOidcForWorkspace(t *testing.T) {
|
||||
p := &PersistentAuth{
|
||||
Host: "abc",
|
||||
http: mockGet(func(url string) (*http.Response, error) {
|
||||
assert.Equal(t, "https://abc/oidc/.well-known/oauth-authorization-server", url)
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: io.NopCloser(strings.NewReader(`{
|
||||
"authorization_endpoint": "a",
|
||||
"token_endpoint": "b"
|
||||
}`)),
|
||||
}, nil
|
||||
}),
|
||||
}
|
||||
defer p.Close()
|
||||
endpoints, err := p.oidcEndpoints()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "a", endpoints.AuthorizationEndpoint)
|
||||
assert.Equal(t, "b", endpoints.TokenEndpoint)
|
||||
}
|
||||
|
||||
type tokenCacheMock struct {
|
||||
store func(key string, t *oauth2.Token) error
|
||||
lookup func(key string) (*oauth2.Token, error)
|
||||
}
|
||||
|
||||
func (m *tokenCacheMock) Store(key string, t *oauth2.Token) error {
|
||||
if m.store == nil {
|
||||
panic("no store mock")
|
||||
}
|
||||
return m.store(key, t)
|
||||
}
|
||||
|
||||
func (m *tokenCacheMock) Lookup(key string) (*oauth2.Token, error) {
|
||||
if m.lookup == nil {
|
||||
panic("no lookup mock")
|
||||
}
|
||||
return m.lookup(key)
|
||||
}
|
||||
|
||||
func TestLoad(t *testing.T) {
|
||||
p := &PersistentAuth{
|
||||
Host: "abc",
|
||||
AccountID: "xyz",
|
||||
cache: &tokenCacheMock{
|
||||
lookup: func(key string) (*oauth2.Token, error) {
|
||||
assert.Equal(t, "https://abc/oidc/accounts/xyz", key)
|
||||
return &oauth2.Token{
|
||||
AccessToken: "bcd",
|
||||
Expiry: time.Now().Add(1 * time.Minute),
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
defer p.Close()
|
||||
tok, err := p.Load(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "bcd", tok.AccessToken)
|
||||
assert.Equal(t, "", tok.RefreshToken)
|
||||
}
|
||||
|
||||
func useInsecureOAuthHttpClientForTests(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, oauth2.HTTPClient, &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadRefresh(t *testing.T) {
|
||||
qa.HTTPFixtures{
|
||||
{
|
||||
Method: "POST",
|
||||
Resource: "/oidc/accounts/xyz/v1/token",
|
||||
Response: `access_token=refreshed&refresh_token=def`,
|
||||
},
|
||||
}.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) {
|
||||
ctx = useInsecureOAuthHttpClientForTests(ctx)
|
||||
expectedKey := fmt.Sprintf("%s/oidc/accounts/xyz", c.Config.Host)
|
||||
p := &PersistentAuth{
|
||||
Host: c.Config.Host,
|
||||
AccountID: "xyz",
|
||||
cache: &tokenCacheMock{
|
||||
lookup: func(key string) (*oauth2.Token, error) {
|
||||
assert.Equal(t, expectedKey, key)
|
||||
return &oauth2.Token{
|
||||
AccessToken: "expired",
|
||||
RefreshToken: "cde",
|
||||
Expiry: time.Now().Add(-1 * time.Minute),
|
||||
}, nil
|
||||
},
|
||||
store: func(key string, tok *oauth2.Token) error {
|
||||
assert.Equal(t, expectedKey, key)
|
||||
assert.Equal(t, "def", tok.RefreshToken)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
defer p.Close()
|
||||
tok, err := p.Load(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "refreshed", tok.AccessToken)
|
||||
assert.Equal(t, "", tok.RefreshToken)
|
||||
})
|
||||
}
|
||||
|
||||
func TestChallenge(t *testing.T) {
|
||||
qa.HTTPFixtures{
|
||||
{
|
||||
Method: "POST",
|
||||
Resource: "/oidc/accounts/xyz/v1/token",
|
||||
Response: `access_token=__THAT__&refresh_token=__SOMETHING__`,
|
||||
},
|
||||
}.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) {
|
||||
ctx = useInsecureOAuthHttpClientForTests(ctx)
|
||||
expectedKey := fmt.Sprintf("%s/oidc/accounts/xyz", c.Config.Host)
|
||||
|
||||
browserOpened := make(chan string)
|
||||
p := &PersistentAuth{
|
||||
Host: c.Config.Host,
|
||||
AccountID: "xyz",
|
||||
browser: func(redirect string) error {
|
||||
u, err := url.ParseRequestURI(redirect)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
assert.Equal(t, "/oidc/accounts/xyz/v1/authorize", u.Path)
|
||||
// for now we're ignoring asserting the fields of the redirect
|
||||
query := u.Query()
|
||||
browserOpened <- query.Get("state")
|
||||
return nil
|
||||
},
|
||||
cache: &tokenCacheMock{
|
||||
store: func(key string, tok *oauth2.Token) error {
|
||||
assert.Equal(t, expectedKey, key)
|
||||
assert.Equal(t, "__SOMETHING__", tok.RefreshToken)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
defer p.Close()
|
||||
|
||||
errc := make(chan error)
|
||||
go func() {
|
||||
errc <- p.Challenge(ctx)
|
||||
}()
|
||||
|
||||
state := <-browserOpened
|
||||
resp, err := http.Get(fmt.Sprintf("http://%s?code=__THIS__&state=%s", appRedirectAddr, state))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
|
||||
err = <-errc
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestChallengeFailed(t *testing.T) {
|
||||
qa.HTTPFixtures{}.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) {
|
||||
ctx = useInsecureOAuthHttpClientForTests(ctx)
|
||||
|
||||
browserOpened := make(chan string)
|
||||
p := &PersistentAuth{
|
||||
Host: c.Config.Host,
|
||||
AccountID: "xyz",
|
||||
browser: func(redirect string) error {
|
||||
u, err := url.ParseRequestURI(redirect)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
assert.Equal(t, "/oidc/accounts/xyz/v1/authorize", u.Path)
|
||||
// for now we're ignoring asserting the fields of the redirect
|
||||
query := u.Query()
|
||||
browserOpened <- query.Get("state")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
defer p.Close()
|
||||
|
||||
errc := make(chan error)
|
||||
go func() {
|
||||
errc <- p.Challenge(ctx)
|
||||
}()
|
||||
|
||||
<-browserOpened
|
||||
resp, err := http.Get(fmt.Sprintf(
|
||||
"http://%s?error=access_denied&error_description=Policy%%20evaluation%%20failed%%20for%%20this%%20request",
|
||||
appRedirectAddr))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 400, resp.StatusCode)
|
||||
|
||||
err = <-errc
|
||||
assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request")
|
||||
})
|
||||
}
|
|
@ -0,0 +1,102 @@
|
|||
<!DOCTYPE html SYSTEM "http://www.thymeleaf.org/dtd/xhtml1-strict-thymeleaf-4.dtd">
|
||||
|
||||
<html xmlns="http://www.w3.org/1999/xhtml" xmlns:th="http://www.thymeleaf.org">
|
||||
<head>
|
||||
<title>{{if .Error }}{{ .Error | title }}{{ else }}Success{{end}}</title>
|
||||
<link rel="preconnect" href="https://fonts.gstatic.com" />
|
||||
<link href="https://fonts.googleapis.com/css2?family=DM+Sans&display=swap" rel="stylesheet" />
|
||||
|
||||
<style>
|
||||
html,
|
||||
body {
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: "DM Sans";
|
||||
font-style: normal;
|
||||
font-size: 14px;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
height: 100%;
|
||||
width: 100%;
|
||||
background: #f5f6f6;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.root-container {
|
||||
display: flex;
|
||||
height: 100%;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.info-container {
|
||||
width: 320px;
|
||||
box-shadow: 0px 2px 4px rgba(0, 0, 0, 0.1),
|
||||
0px 8px 25px rgba(0, 0, 0, 0.1);
|
||||
border-radius: 8px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
padding: 48px;
|
||||
background: #fff;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
text-align: center;
|
||||
gap: 24px;
|
||||
}
|
||||
|
||||
.title {
|
||||
font-weight: 600;
|
||||
font-size: 24px;
|
||||
line-height: 28px;
|
||||
}
|
||||
|
||||
a {
|
||||
color: #C4CCD6;
|
||||
}
|
||||
|
||||
a:hover {
|
||||
color: #90A5B1;
|
||||
}
|
||||
|
||||
.content {
|
||||
width: 300px;
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
.button {
|
||||
display: flex;
|
||||
background: #1B3139;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
height: 40px;
|
||||
width: 300px;
|
||||
border-radius: 4px;
|
||||
text-align: center;
|
||||
text-decoration: none;
|
||||
color: #ffffff !important;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<div class="root-container">
|
||||
<div class="info-container">
|
||||
<img
|
||||
src=""
|
||||
/>
|
||||
<!-- {{ if .Error }} -->
|
||||
<div class="title">{{ .Error | title }}</div>
|
||||
<div class="content">{{ .ErrorDescription }}</div>
|
||||
<!-- {{ else }} -->
|
||||
<div class="title">Authenticated</div>
|
||||
<div class="content">Go to <a href="https://{{.Host}}">{{.Host}}</a></div>
|
||||
<!-- {{ end }} -->
|
||||
<div class="content">
|
||||
You can close this tab. Or go to <a href="https://docs.databricks.com/dev-tools/index-cli.html">documentation</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
Loading…
Reference in New Issue