mirror of https://github.com/databricks/cli.git
Provide a flag on auth commands that allows binding the OAUTH redirect address to a port rather than an address, allowing binding to public IP addresses rather than just localhost
This commit is contained in:
parent
fd8dbff631
commit
daf86ea0a5
|
@ -17,7 +17,7 @@ func New() *cobra.Command {
|
|||
var perisistentAuth auth.PersistentAuth
|
||||
cmd.PersistentFlags().StringVar(&perisistentAuth.Host, "host", perisistentAuth.Host, "Databricks Host")
|
||||
cmd.PersistentFlags().StringVar(&perisistentAuth.AccountID, "account-id", perisistentAuth.AccountID, "Databricks Account ID")
|
||||
|
||||
cmd.PersistentFlags().BoolVar(&perisistentAuth.BindPublicAddress, "bind-public", perisistentAuth.BindPublicAddress, "Allow OAUTH redirect to bind to all local IP addresses including public addresses (NOTE: this is less secure)")
|
||||
cmd.AddCommand(newEnvCommand())
|
||||
cmd.AddCommand(newLoginCommand(&perisistentAuth))
|
||||
cmd.AddCommand(newProfilesCommand())
|
||||
|
|
|
@ -24,8 +24,9 @@ const (
|
|||
// these values are predefined by Databricks as a public client
|
||||
// and is specific to this application only. Using these values
|
||||
// for other applications is not allowed.
|
||||
appClientID = "databricks-cli"
|
||||
appRedirectAddr = "localhost:8020"
|
||||
appClientID = "databricks-cli"
|
||||
appRedirectPort = ":8020"
|
||||
defaultAppRedirectAddr = "localhost" + appRedirectPort
|
||||
|
||||
// maximum amount of time to acquire listener on appRedirectAddr
|
||||
DefaultTimeout = 45 * time.Second
|
||||
|
@ -41,10 +42,12 @@ type PersistentAuth struct {
|
|||
Host string
|
||||
AccountID string
|
||||
|
||||
http *httpclient.ApiClient
|
||||
cache tokenCache
|
||||
ln net.Listener
|
||||
browser func(string) error
|
||||
http *httpclient.ApiClient
|
||||
cache tokenCache
|
||||
ln net.Listener
|
||||
browser func(string) error
|
||||
BindPublicAddress bool
|
||||
BoundAddress string
|
||||
}
|
||||
|
||||
type tokenCache interface {
|
||||
|
@ -146,13 +149,24 @@ func (a *PersistentAuth) init(ctx context.Context) error {
|
|||
if a.browser == nil {
|
||||
a.browser = browser.OpenURL
|
||||
}
|
||||
|
||||
// For various use cases need to bind to the port rather than an address, otherwise
|
||||
// we only bind to a single IP which may or may not be correct. This is controlled
|
||||
// by the BindPublicAddress flag. By default we will just used the defaultAppRedirectAddr
|
||||
// which is localhost. See: https://pkg.go.dev/net#ListenIP for issues with this.
|
||||
if a.BindPublicAddress {
|
||||
a.BoundAddress = appRedirectPort
|
||||
} else {
|
||||
a.BoundAddress = defaultAppRedirectAddr
|
||||
}
|
||||
|
||||
// try acquire listener, which we also use as a machine-local
|
||||
// exclusive lock to prevent token cache corruption in the scope
|
||||
// of developer machine, where this command runs.
|
||||
listener, err := retries.Poll(ctx, DefaultTimeout,
|
||||
func() (*net.Listener, *retries.Err) {
|
||||
var lc net.ListenConfig
|
||||
l, err := lc.Listen(ctx, "tcp", appRedirectAddr)
|
||||
l, err := lc.Listen(ctx, "tcp", a.BoundAddress)
|
||||
if err != nil {
|
||||
return nil, retries.Continue(err)
|
||||
}
|
||||
|
@ -213,7 +227,7 @@ func (a *PersistentAuth) oauth2Config(ctx context.Context) (*oauth2.Config, erro
|
|||
TokenURL: endpoints.TokenEndpoint,
|
||||
AuthStyle: oauth2.AuthStyleInParams,
|
||||
},
|
||||
RedirectURL: fmt.Sprintf("http://%s", appRedirectAddr),
|
||||
RedirectURL: fmt.Sprintf("http://%s", a.BoundAddress),
|
||||
Scopes: scopes,
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -181,7 +181,7 @@ func TestChallenge(t *testing.T) {
|
|||
}()
|
||||
|
||||
state := <-browserOpened
|
||||
resp, err := http.Get(fmt.Sprintf("http://%s?code=__THIS__&state=%s", appRedirectAddr, state))
|
||||
resp, err := http.Get(fmt.Sprintf("http://%s?code=__THIS__&state=%s", defaultAppRedirectAddr, state))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
|
||||
|
@ -220,7 +220,7 @@ func TestChallengeFailed(t *testing.T) {
|
|||
<-browserOpened
|
||||
resp, err := http.Get(fmt.Sprintf(
|
||||
"http://%s?error=access_denied&error_description=Policy%%20evaluation%%20failed%%20for%%20this%%20request",
|
||||
appRedirectAddr))
|
||||
defaultAppRedirectAddr))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 400, resp.StatusCode)
|
||||
|
||||
|
@ -228,3 +228,45 @@ func TestChallengeFailed(t *testing.T) {
|
|||
assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request")
|
||||
})
|
||||
}
|
||||
|
||||
func TestBindPublicAddress(t *testing.T) {
|
||||
p := &PersistentAuth{
|
||||
Host: "abc",
|
||||
AccountID: "xyz",
|
||||
cache: &tokenCacheMock{
|
||||
lookup: func(key string) (*oauth2.Token, error) {
|
||||
assert.Equal(t, "https://abc/oidc/accounts/xyz", key)
|
||||
return &oauth2.Token{
|
||||
AccessToken: "bcd",
|
||||
Expiry: time.Now().Add(1 * time.Minute),
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
BindPublicAddress: true,
|
||||
}
|
||||
defer p.Close()
|
||||
_, err := p.Load(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "[::]:8020", p.ln.Addr().String())
|
||||
}
|
||||
|
||||
func TestBindPrivateAddressOnly(t *testing.T) {
|
||||
p := &PersistentAuth{
|
||||
Host: "abc",
|
||||
AccountID: "xyz",
|
||||
cache: &tokenCacheMock{
|
||||
lookup: func(key string) (*oauth2.Token, error) {
|
||||
assert.Equal(t, "https://abc/oidc/accounts/xyz", key)
|
||||
return &oauth2.Token{
|
||||
AccessToken: "bcd",
|
||||
Expiry: time.Now().Add(1 * time.Minute),
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
BindPublicAddress: false,
|
||||
}
|
||||
defer p.Close()
|
||||
_, err := p.Load(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "127.0.0.1:8020", p.ln.Addr().String())
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue