mirror of https://github.com/databricks/cli.git
Merge ca08796f77
into 4b069bb6e1
This commit is contained in:
commit
20aea8fd03
|
@ -31,6 +31,7 @@ GCP: https://docs.gcp.databricks.com/dev-tools/auth/index.html`,
|
||||||
cmd.AddCommand(newProfilesCommand())
|
cmd.AddCommand(newProfilesCommand())
|
||||||
cmd.AddCommand(newTokenCommand(&perisistentAuth))
|
cmd.AddCommand(newTokenCommand(&perisistentAuth))
|
||||||
cmd.AddCommand(newDescribeCommand())
|
cmd.AddCommand(newDescribeCommand())
|
||||||
|
cmd.AddCommand(newLogoutCommand(&perisistentAuth))
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,110 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
|
||||||
|
"github.com/databricks/cli/libs/auth"
|
||||||
|
"github.com/databricks/cli/libs/auth/cache"
|
||||||
|
"github.com/databricks/cli/libs/cmdio"
|
||||||
|
"github.com/databricks/cli/libs/databrickscfg/profile"
|
||||||
|
"github.com/databricks/databricks-sdk-go/config"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
type logoutSession struct {
|
||||||
|
profile string
|
||||||
|
file config.File
|
||||||
|
persistentAuth *auth.PersistentAuth
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *logoutSession) load(ctx context.Context, profileName string, persistentAuth *auth.PersistentAuth) error {
|
||||||
|
l.profile = profileName
|
||||||
|
l.persistentAuth = persistentAuth
|
||||||
|
iniFile, err := profile.DefaultProfiler.Get(ctx)
|
||||||
|
if errors.Is(err, fs.ErrNotExist) {
|
||||||
|
return err
|
||||||
|
} else if err != nil {
|
||||||
|
return fmt.Errorf("cannot parse config file: %w", err)
|
||||||
|
}
|
||||||
|
l.file = *iniFile
|
||||||
|
if err := l.setHostAndAccountIdFromProfile(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *logoutSession) setHostAndAccountIdFromProfile() error {
|
||||||
|
sectionMap, err := l.getConfigSectionMap()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if sectionMap["host"] == "" {
|
||||||
|
return fmt.Errorf("no host configured for profile %s", l.profile)
|
||||||
|
}
|
||||||
|
l.persistentAuth.Host = sectionMap["host"]
|
||||||
|
l.persistentAuth.AccountID = sectionMap["account_id"]
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *logoutSession) getConfigSectionMap() (map[string]string, error) {
|
||||||
|
section, err := l.file.GetSection(l.profile)
|
||||||
|
if err != nil {
|
||||||
|
return map[string]string{}, fmt.Errorf("profile does not exist in config file: %w", err)
|
||||||
|
}
|
||||||
|
return section.KeysHash(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear token from ~/.databricks/token-cache.json
|
||||||
|
func (l *logoutSession) clearTokenCache(ctx context.Context) error {
|
||||||
|
return l.persistentAuth.ClearToken(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLogoutCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
|
||||||
|
cmd := &cobra.Command{
|
||||||
|
Use: "logout [PROFILE]",
|
||||||
|
Short: "Logout from specified profile",
|
||||||
|
Long: "Removes the OAuth token from the token-cache",
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.RunE = func(cmd *cobra.Command, args []string) error {
|
||||||
|
ctx := cmd.Context()
|
||||||
|
profileNameFromFlag := cmd.Flag("profile").Value.String()
|
||||||
|
// If both [PROFILE] and --profile are provided, return an error.
|
||||||
|
if len(args) > 0 && profileNameFromFlag != "" {
|
||||||
|
return fmt.Errorf("please only provide a profile as an argument or a flag, not both")
|
||||||
|
}
|
||||||
|
// Determine the profile name from either args or the flag.
|
||||||
|
profileName := profileNameFromFlag
|
||||||
|
if len(args) > 0 {
|
||||||
|
profileName = args[0]
|
||||||
|
}
|
||||||
|
// If the user has not specified a profile name, prompt for one.
|
||||||
|
if profileName == "" {
|
||||||
|
var err error
|
||||||
|
profileName, err = promptForProfile(ctx, persistentAuth.ProfileName())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
defer persistentAuth.Close()
|
||||||
|
logoutSession := &logoutSession{}
|
||||||
|
err := logoutSession.load(ctx, profileName, persistentAuth)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = logoutSession.clearTokenCache(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, cache.ErrNotConfigured) {
|
||||||
|
// It is OK to not have OAuth configured
|
||||||
|
} else {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cmdio.LogString(ctx, fmt.Sprintf("Profile %s is logged out", profileName))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return cmd
|
||||||
|
}
|
|
@ -0,0 +1,62 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/databricks/cli/libs/auth"
|
||||||
|
"github.com/databricks/cli/libs/databrickscfg"
|
||||||
|
"github.com/databricks/databricks-sdk-go/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLogout_setHostAndAccountIdFromProfile(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
path := filepath.Join(t.TempDir(), "databrickscfg")
|
||||||
|
|
||||||
|
err := databrickscfg.SaveToProfile(ctx, &config.Config{
|
||||||
|
ConfigFile: path,
|
||||||
|
Profile: "abc",
|
||||||
|
Host: "https://foo",
|
||||||
|
Token: "xyz",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
iniFile, err := config.LoadFile(path)
|
||||||
|
require.NoError(t, err)
|
||||||
|
logout := &logoutSession{
|
||||||
|
profile: "abc",
|
||||||
|
file: *iniFile,
|
||||||
|
persistentAuth: &auth.PersistentAuth{},
|
||||||
|
}
|
||||||
|
err = logout.setHostAndAccountIdFromProfile()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, logout.persistentAuth.Host, "https://foo")
|
||||||
|
assert.Empty(t, logout.persistentAuth.AccountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogout_getConfigSectionMap(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
path := filepath.Join(t.TempDir(), "databrickscfg")
|
||||||
|
|
||||||
|
err := databrickscfg.SaveToProfile(ctx, &config.Config{
|
||||||
|
ConfigFile: path,
|
||||||
|
Profile: "abc",
|
||||||
|
Host: "https://foo",
|
||||||
|
Token: "xyz",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
iniFile, err := config.LoadFile(path)
|
||||||
|
require.NoError(t, err)
|
||||||
|
logout := &logoutSession{
|
||||||
|
profile: "abc",
|
||||||
|
file: *iniFile,
|
||||||
|
persistentAuth: &auth.PersistentAuth{},
|
||||||
|
}
|
||||||
|
configSectionMap, err := logout.getConfigSectionMap()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, configSectionMap["host"], "https://foo")
|
||||||
|
assert.Equal(t, configSectionMap["token"], "xyz")
|
||||||
|
}
|
|
@ -9,6 +9,7 @@ import (
|
||||||
type TokenCache interface {
|
type TokenCache interface {
|
||||||
Store(key string, t *oauth2.Token) error
|
Store(key string, t *oauth2.Token) error
|
||||||
Lookup(key string) (*oauth2.Token, error)
|
Lookup(key string) (*oauth2.Token, error)
|
||||||
|
Delete(key string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
var tokenCache int
|
var tokenCache int
|
||||||
|
|
|
@ -52,11 +52,7 @@ func (c *FileTokenCache) Store(key string, t *oauth2.Token) error {
|
||||||
c.Tokens = map[string]*oauth2.Token{}
|
c.Tokens = map[string]*oauth2.Token{}
|
||||||
}
|
}
|
||||||
c.Tokens[key] = t
|
c.Tokens[key] = t
|
||||||
raw, err := json.MarshalIndent(c, "", " ")
|
return c.write()
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("marshal: %w", err)
|
|
||||||
}
|
|
||||||
return os.WriteFile(c.fileLocation, raw, ownerReadWrite)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *FileTokenCache) Lookup(key string) (*oauth2.Token, error) {
|
func (c *FileTokenCache) Lookup(key string) (*oauth2.Token, error) {
|
||||||
|
@ -73,6 +69,24 @@ func (c *FileTokenCache) Lookup(key string) (*oauth2.Token, error) {
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *FileTokenCache) Delete(key string) error {
|
||||||
|
err := c.load()
|
||||||
|
if errors.Is(err, fs.ErrNotExist) {
|
||||||
|
return ErrNotConfigured
|
||||||
|
} else if err != nil {
|
||||||
|
return fmt.Errorf("load: %w", err)
|
||||||
|
}
|
||||||
|
if c.Tokens == nil {
|
||||||
|
c.Tokens = map[string]*oauth2.Token{}
|
||||||
|
}
|
||||||
|
_, ok := c.Tokens[key]
|
||||||
|
if !ok {
|
||||||
|
return ErrNotConfigured
|
||||||
|
}
|
||||||
|
delete(c.Tokens, key)
|
||||||
|
return c.write()
|
||||||
|
}
|
||||||
|
|
||||||
func (c *FileTokenCache) location() (string, error) {
|
func (c *FileTokenCache) location() (string, error) {
|
||||||
home, err := os.UserHomeDir()
|
home, err := os.UserHomeDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -105,4 +119,12 @@ func (c *FileTokenCache) load() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *FileTokenCache) write() error {
|
||||||
|
raw, err := json.MarshalIndent(c, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal: %w", err)
|
||||||
|
}
|
||||||
|
return os.WriteFile(c.fileLocation, raw, ownerReadWrite)
|
||||||
|
}
|
||||||
|
|
||||||
var _ TokenCache = (*FileTokenCache)(nil)
|
var _ TokenCache = (*FileTokenCache)(nil)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package cache
|
package cache
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
@ -103,3 +104,64 @@ func TestStoreOnDev(t *testing.T) {
|
||||||
// macOS: read-only file system
|
// macOS: read-only file system
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStoreAndDeleteKey(t *testing.T) {
|
||||||
|
setup(t)
|
||||||
|
c := &FileTokenCache{}
|
||||||
|
err := c.Store("x", &oauth2.Token{
|
||||||
|
AccessToken: "abc",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = c.Store("y", &oauth2.Token{
|
||||||
|
AccessToken: "bcd",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
l := &FileTokenCache{}
|
||||||
|
err = l.Delete("x")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 1, len(l.Tokens))
|
||||||
|
|
||||||
|
_, err = l.Lookup("x")
|
||||||
|
assert.Equal(t, ErrNotConfigured, err)
|
||||||
|
|
||||||
|
tok, err := l.Lookup("y")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "bcd", tok.AccessToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteKeyNotExist(t *testing.T) {
|
||||||
|
c := &FileTokenCache{
|
||||||
|
Tokens: map[string]*oauth2.Token{},
|
||||||
|
}
|
||||||
|
err := c.Delete("x")
|
||||||
|
assert.Equal(t, ErrNotConfigured, err)
|
||||||
|
|
||||||
|
_, err = c.Lookup("x")
|
||||||
|
assert.Equal(t, ErrNotConfigured, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrite(t *testing.T) {
|
||||||
|
tempFile := filepath.Join(t.TempDir(), "token-cache.json")
|
||||||
|
|
||||||
|
tokenMap := map[string]*oauth2.Token{}
|
||||||
|
token := &oauth2.Token{
|
||||||
|
AccessToken: "some-access-token",
|
||||||
|
}
|
||||||
|
tokenMap["test"] = token
|
||||||
|
|
||||||
|
cache := &FileTokenCache{
|
||||||
|
fileLocation: tempFile,
|
||||||
|
Tokens: tokenMap,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := cache.write()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
content, err := os.ReadFile(tempFile)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
expected, _ := json.MarshalIndent(&cache, "", " ")
|
||||||
|
assert.Equal(t, content, expected)
|
||||||
|
}
|
||||||
|
|
|
@ -23,4 +23,14 @@ func (i *InMemoryTokenCache) Store(key string, t *oauth2.Token) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Delete implements TokenCache.
|
||||||
|
func (i *InMemoryTokenCache) Delete(key string) error {
|
||||||
|
_, ok := i.Tokens[key]
|
||||||
|
if !ok {
|
||||||
|
return ErrNotConfigured
|
||||||
|
}
|
||||||
|
delete(i.Tokens, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var _ TokenCache = (*InMemoryTokenCache)(nil)
|
var _ TokenCache = (*InMemoryTokenCache)(nil)
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -42,3 +43,40 @@ func TestInMemoryCacheStore(t *testing.T) {
|
||||||
assert.Equal(t, res, token)
|
assert.Equal(t, res, token)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestInMemoryDeleteKey(t *testing.T) {
|
||||||
|
c := &InMemoryTokenCache{
|
||||||
|
Tokens: map[string]*oauth2.Token{},
|
||||||
|
}
|
||||||
|
err := c.Store("x", &oauth2.Token{
|
||||||
|
AccessToken: "abc",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = c.Store("y", &oauth2.Token{
|
||||||
|
AccessToken: "bcd",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = c.Delete("x")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 1, len(c.Tokens))
|
||||||
|
|
||||||
|
_, err = c.Lookup("x")
|
||||||
|
assert.Equal(t, ErrNotConfigured, err)
|
||||||
|
|
||||||
|
tok, err := c.Lookup("y")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "bcd", tok.AccessToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryDeleteKeyNotExist(t *testing.T) {
|
||||||
|
c := &InMemoryTokenCache{
|
||||||
|
Tokens: map[string]*oauth2.Token{},
|
||||||
|
}
|
||||||
|
err := c.Delete("x")
|
||||||
|
assert.Equal(t, ErrNotConfigured, err)
|
||||||
|
|
||||||
|
_, err = c.Lookup("x")
|
||||||
|
assert.Equal(t, ErrNotConfigured, err)
|
||||||
|
}
|
||||||
|
|
|
@ -144,6 +144,18 @@ func (a *PersistentAuth) Challenge(ctx context.Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *PersistentAuth) ClearToken(ctx context.Context) error {
|
||||||
|
if a.Host == "" && a.AccountID == "" {
|
||||||
|
return ErrFetchCredentials
|
||||||
|
}
|
||||||
|
if a.cache == nil {
|
||||||
|
a.cache = cache.GetTokenCache(ctx)
|
||||||
|
}
|
||||||
|
// lookup token identified by host (and possibly the account id)
|
||||||
|
key := a.key()
|
||||||
|
return a.cache.Delete(key)
|
||||||
|
}
|
||||||
|
|
||||||
// This function cleans up the host URL by only retaining the scheme and the host.
|
// This function cleans up the host URL by only retaining the scheme and the host.
|
||||||
// This function thus removes any path, query arguments, or fragments from the URL.
|
// This function thus removes any path, query arguments, or fragments from the URL.
|
||||||
func (a *PersistentAuth) cleanHost() {
|
func (a *PersistentAuth) cleanHost() {
|
||||||
|
|
|
@ -55,6 +55,7 @@ func TestOidcForWorkspace(t *testing.T) {
|
||||||
type tokenCacheMock struct {
|
type tokenCacheMock struct {
|
||||||
store func(key string, t *oauth2.Token) error
|
store func(key string, t *oauth2.Token) error
|
||||||
lookup func(key string) (*oauth2.Token, error)
|
lookup func(key string) (*oauth2.Token, error)
|
||||||
|
delete func(key string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *tokenCacheMock) Store(key string, t *oauth2.Token) error {
|
func (m *tokenCacheMock) Store(key string, t *oauth2.Token) error {
|
||||||
|
@ -71,6 +72,13 @@ func (m *tokenCacheMock) Lookup(key string) (*oauth2.Token, error) {
|
||||||
return m.lookup(key)
|
return m.lookup(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *tokenCacheMock) Delete(key string) error {
|
||||||
|
if m.delete == nil {
|
||||||
|
panic("no deleteKey mock")
|
||||||
|
}
|
||||||
|
return m.delete(key)
|
||||||
|
}
|
||||||
|
|
||||||
func TestLoad(t *testing.T) {
|
func TestLoad(t *testing.T) {
|
||||||
p := &PersistentAuth{
|
p := &PersistentAuth{
|
||||||
Host: "abc",
|
Host: "abc",
|
||||||
|
@ -229,6 +237,52 @@ func TestChallengeFailed(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClearToken(t *testing.T) {
|
||||||
|
p := &PersistentAuth{
|
||||||
|
Host: "abc",
|
||||||
|
AccountID: "xyz",
|
||||||
|
cache: &tokenCacheMock{
|
||||||
|
lookup: func(key string) (*oauth2.Token, error) {
|
||||||
|
assert.Equal(t, "https://abc/oidc/accounts/xyz", key)
|
||||||
|
return &oauth2.Token{}, ErrNotConfigured
|
||||||
|
},
|
||||||
|
delete: func(key string) error {
|
||||||
|
assert.Equal(t, "https://abc/oidc/accounts/xyz", key)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
defer p.Close()
|
||||||
|
err := p.ClearToken(context.Background())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
key := p.key()
|
||||||
|
_, err = p.cache.Lookup(key)
|
||||||
|
assert.Equal(t, ErrNotConfigured, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClearTokenNotExist(t *testing.T) {
|
||||||
|
p := &PersistentAuth{
|
||||||
|
Host: "abc",
|
||||||
|
AccountID: "xyz",
|
||||||
|
cache: &tokenCacheMock{
|
||||||
|
lookup: func(key string) (*oauth2.Token, error) {
|
||||||
|
assert.Equal(t, "https://abc/oidc/accounts/xyz", key)
|
||||||
|
return &oauth2.Token{}, ErrNotConfigured
|
||||||
|
},
|
||||||
|
delete: func(key string) error {
|
||||||
|
assert.Equal(t, "https://abc/oidc/accounts/xyz", key)
|
||||||
|
return ErrNotConfigured
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
defer p.Close()
|
||||||
|
err := p.ClearToken(context.Background())
|
||||||
|
assert.Equal(t, ErrNotConfigured, err)
|
||||||
|
key := p.key()
|
||||||
|
_, err = p.cache.Lookup(key)
|
||||||
|
assert.Equal(t, ErrNotConfigured, err)
|
||||||
|
}
|
||||||
|
|
||||||
func TestPersistentAuthCleanHost(t *testing.T) {
|
func TestPersistentAuthCleanHost(t *testing.T) {
|
||||||
for _, tcases := range []struct {
|
for _, tcases := range []struct {
|
||||||
in string
|
in string
|
||||||
|
|
Loading…
Reference in New Issue