mirror of https://github.com/databricks/cli.git
Provide a flag on auth commands that allows binding the OAUTH redirect address to a port rather than an address, allowing binding to public IP addresses rather than just localhost
This commit is contained in:
parent
fd8dbff631
commit
daf86ea0a5
|
@ -17,7 +17,7 @@ func New() *cobra.Command {
|
||||||
var perisistentAuth auth.PersistentAuth
|
var perisistentAuth auth.PersistentAuth
|
||||||
cmd.PersistentFlags().StringVar(&perisistentAuth.Host, "host", perisistentAuth.Host, "Databricks Host")
|
cmd.PersistentFlags().StringVar(&perisistentAuth.Host, "host", perisistentAuth.Host, "Databricks Host")
|
||||||
cmd.PersistentFlags().StringVar(&perisistentAuth.AccountID, "account-id", perisistentAuth.AccountID, "Databricks Account ID")
|
cmd.PersistentFlags().StringVar(&perisistentAuth.AccountID, "account-id", perisistentAuth.AccountID, "Databricks Account ID")
|
||||||
|
cmd.PersistentFlags().BoolVar(&perisistentAuth.BindPublicAddress, "bind-public", perisistentAuth.BindPublicAddress, "Allow OAUTH redirect to bind to all local IP addresses including public addresses (NOTE: this is less secure)")
|
||||||
cmd.AddCommand(newEnvCommand())
|
cmd.AddCommand(newEnvCommand())
|
||||||
cmd.AddCommand(newLoginCommand(&perisistentAuth))
|
cmd.AddCommand(newLoginCommand(&perisistentAuth))
|
||||||
cmd.AddCommand(newProfilesCommand())
|
cmd.AddCommand(newProfilesCommand())
|
||||||
|
|
|
@ -24,8 +24,9 @@ const (
|
||||||
// these values are predefined by Databricks as a public client
|
// these values are predefined by Databricks as a public client
|
||||||
// and is specific to this application only. Using these values
|
// and is specific to this application only. Using these values
|
||||||
// for other applications is not allowed.
|
// for other applications is not allowed.
|
||||||
appClientID = "databricks-cli"
|
appClientID = "databricks-cli"
|
||||||
appRedirectAddr = "localhost:8020"
|
appRedirectPort = ":8020"
|
||||||
|
defaultAppRedirectAddr = "localhost" + appRedirectPort
|
||||||
|
|
||||||
// maximum amount of time to acquire listener on appRedirectAddr
|
// maximum amount of time to acquire listener on appRedirectAddr
|
||||||
DefaultTimeout = 45 * time.Second
|
DefaultTimeout = 45 * time.Second
|
||||||
|
@ -41,10 +42,12 @@ type PersistentAuth struct {
|
||||||
Host string
|
Host string
|
||||||
AccountID string
|
AccountID string
|
||||||
|
|
||||||
http *httpclient.ApiClient
|
http *httpclient.ApiClient
|
||||||
cache tokenCache
|
cache tokenCache
|
||||||
ln net.Listener
|
ln net.Listener
|
||||||
browser func(string) error
|
browser func(string) error
|
||||||
|
BindPublicAddress bool
|
||||||
|
BoundAddress string
|
||||||
}
|
}
|
||||||
|
|
||||||
type tokenCache interface {
|
type tokenCache interface {
|
||||||
|
@ -146,13 +149,24 @@ func (a *PersistentAuth) init(ctx context.Context) error {
|
||||||
if a.browser == nil {
|
if a.browser == nil {
|
||||||
a.browser = browser.OpenURL
|
a.browser = browser.OpenURL
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For various use cases need to bind to the port rather than an address, otherwise
|
||||||
|
// we only bind to a single IP which may or may not be correct. This is controlled
|
||||||
|
// by the BindPublicAddress flag. By default we will just used the defaultAppRedirectAddr
|
||||||
|
// which is localhost. See: https://pkg.go.dev/net#ListenIP for issues with this.
|
||||||
|
if a.BindPublicAddress {
|
||||||
|
a.BoundAddress = appRedirectPort
|
||||||
|
} else {
|
||||||
|
a.BoundAddress = defaultAppRedirectAddr
|
||||||
|
}
|
||||||
|
|
||||||
// try acquire listener, which we also use as a machine-local
|
// try acquire listener, which we also use as a machine-local
|
||||||
// exclusive lock to prevent token cache corruption in the scope
|
// exclusive lock to prevent token cache corruption in the scope
|
||||||
// of developer machine, where this command runs.
|
// of developer machine, where this command runs.
|
||||||
listener, err := retries.Poll(ctx, DefaultTimeout,
|
listener, err := retries.Poll(ctx, DefaultTimeout,
|
||||||
func() (*net.Listener, *retries.Err) {
|
func() (*net.Listener, *retries.Err) {
|
||||||
var lc net.ListenConfig
|
var lc net.ListenConfig
|
||||||
l, err := lc.Listen(ctx, "tcp", appRedirectAddr)
|
l, err := lc.Listen(ctx, "tcp", a.BoundAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, retries.Continue(err)
|
return nil, retries.Continue(err)
|
||||||
}
|
}
|
||||||
|
@ -213,7 +227,7 @@ func (a *PersistentAuth) oauth2Config(ctx context.Context) (*oauth2.Config, erro
|
||||||
TokenURL: endpoints.TokenEndpoint,
|
TokenURL: endpoints.TokenEndpoint,
|
||||||
AuthStyle: oauth2.AuthStyleInParams,
|
AuthStyle: oauth2.AuthStyleInParams,
|
||||||
},
|
},
|
||||||
RedirectURL: fmt.Sprintf("http://%s", appRedirectAddr),
|
RedirectURL: fmt.Sprintf("http://%s", a.BoundAddress),
|
||||||
Scopes: scopes,
|
Scopes: scopes,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -181,7 +181,7 @@ func TestChallenge(t *testing.T) {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
state := <-browserOpened
|
state := <-browserOpened
|
||||||
resp, err := http.Get(fmt.Sprintf("http://%s?code=__THIS__&state=%s", appRedirectAddr, state))
|
resp, err := http.Get(fmt.Sprintf("http://%s?code=__THIS__&state=%s", defaultAppRedirectAddr, state))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, 200, resp.StatusCode)
|
assert.Equal(t, 200, resp.StatusCode)
|
||||||
|
|
||||||
|
@ -220,7 +220,7 @@ func TestChallengeFailed(t *testing.T) {
|
||||||
<-browserOpened
|
<-browserOpened
|
||||||
resp, err := http.Get(fmt.Sprintf(
|
resp, err := http.Get(fmt.Sprintf(
|
||||||
"http://%s?error=access_denied&error_description=Policy%%20evaluation%%20failed%%20for%%20this%%20request",
|
"http://%s?error=access_denied&error_description=Policy%%20evaluation%%20failed%%20for%%20this%%20request",
|
||||||
appRedirectAddr))
|
defaultAppRedirectAddr))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, 400, resp.StatusCode)
|
assert.Equal(t, 400, resp.StatusCode)
|
||||||
|
|
||||||
|
@ -228,3 +228,45 @@ func TestChallengeFailed(t *testing.T) {
|
||||||
assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request")
|
assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBindPublicAddress(t *testing.T) {
|
||||||
|
p := &PersistentAuth{
|
||||||
|
Host: "abc",
|
||||||
|
AccountID: "xyz",
|
||||||
|
cache: &tokenCacheMock{
|
||||||
|
lookup: func(key string) (*oauth2.Token, error) {
|
||||||
|
assert.Equal(t, "https://abc/oidc/accounts/xyz", key)
|
||||||
|
return &oauth2.Token{
|
||||||
|
AccessToken: "bcd",
|
||||||
|
Expiry: time.Now().Add(1 * time.Minute),
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
BindPublicAddress: true,
|
||||||
|
}
|
||||||
|
defer p.Close()
|
||||||
|
_, err := p.Load(context.Background())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "[::]:8020", p.ln.Addr().String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBindPrivateAddressOnly(t *testing.T) {
|
||||||
|
p := &PersistentAuth{
|
||||||
|
Host: "abc",
|
||||||
|
AccountID: "xyz",
|
||||||
|
cache: &tokenCacheMock{
|
||||||
|
lookup: func(key string) (*oauth2.Token, error) {
|
||||||
|
assert.Equal(t, "https://abc/oidc/accounts/xyz", key)
|
||||||
|
return &oauth2.Token{
|
||||||
|
AccessToken: "bcd",
|
||||||
|
Expiry: time.Now().Add(1 * time.Minute),
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
BindPublicAddress: false,
|
||||||
|
}
|
||||||
|
defer p.Close()
|
||||||
|
_, err := p.Load(context.Background())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "127.0.0.1:8020", p.ln.Addr().String())
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue