Provide a flag on auth commands that allows binding the OAUTH redirect address to a port rather than an address, allowing binding to public IP addresses rather than just localhost

This commit is contained in:
Ben Phegan 2024-03-28 09:58:13 +11:00
parent fd8dbff631
commit daf86ea0a5
No known key found for this signature in database
GPG Key ID: A14E0006BF2E3E19
3 changed files with 67 additions and 11 deletions

View File

@ -17,7 +17,7 @@ func New() *cobra.Command {
var perisistentAuth auth.PersistentAuth var perisistentAuth auth.PersistentAuth
cmd.PersistentFlags().StringVar(&perisistentAuth.Host, "host", perisistentAuth.Host, "Databricks Host") cmd.PersistentFlags().StringVar(&perisistentAuth.Host, "host", perisistentAuth.Host, "Databricks Host")
cmd.PersistentFlags().StringVar(&perisistentAuth.AccountID, "account-id", perisistentAuth.AccountID, "Databricks Account ID") cmd.PersistentFlags().StringVar(&perisistentAuth.AccountID, "account-id", perisistentAuth.AccountID, "Databricks Account ID")
cmd.PersistentFlags().BoolVar(&perisistentAuth.BindPublicAddress, "bind-public", perisistentAuth.BindPublicAddress, "Allow OAUTH redirect to bind to all local IP addresses including public addresses (NOTE: this is less secure)")
cmd.AddCommand(newEnvCommand()) cmd.AddCommand(newEnvCommand())
cmd.AddCommand(newLoginCommand(&perisistentAuth)) cmd.AddCommand(newLoginCommand(&perisistentAuth))
cmd.AddCommand(newProfilesCommand()) cmd.AddCommand(newProfilesCommand())

View File

@ -25,7 +25,8 @@ const (
// and is specific to this application only. Using these values // and is specific to this application only. Using these values
// for other applications is not allowed. // for other applications is not allowed.
appClientID = "databricks-cli" appClientID = "databricks-cli"
appRedirectAddr = "localhost:8020" appRedirectPort = ":8020"
defaultAppRedirectAddr = "localhost" + appRedirectPort
// maximum amount of time to acquire listener on appRedirectAddr // maximum amount of time to acquire listener on appRedirectAddr
DefaultTimeout = 45 * time.Second DefaultTimeout = 45 * time.Second
@ -45,6 +46,8 @@ type PersistentAuth struct {
cache tokenCache cache tokenCache
ln net.Listener ln net.Listener
browser func(string) error browser func(string) error
BindPublicAddress bool
BoundAddress string
} }
type tokenCache interface { type tokenCache interface {
@ -146,13 +149,24 @@ func (a *PersistentAuth) init(ctx context.Context) error {
if a.browser == nil { if a.browser == nil {
a.browser = browser.OpenURL a.browser = browser.OpenURL
} }
// For various use cases need to bind to the port rather than an address, otherwise
// we only bind to a single IP which may or may not be correct. This is controlled
// by the BindPublicAddress flag. By default we will just used the defaultAppRedirectAddr
// which is localhost. See: https://pkg.go.dev/net#ListenIP for issues with this.
if a.BindPublicAddress {
a.BoundAddress = appRedirectPort
} else {
a.BoundAddress = defaultAppRedirectAddr
}
// try acquire listener, which we also use as a machine-local // try acquire listener, which we also use as a machine-local
// exclusive lock to prevent token cache corruption in the scope // exclusive lock to prevent token cache corruption in the scope
// of developer machine, where this command runs. // of developer machine, where this command runs.
listener, err := retries.Poll(ctx, DefaultTimeout, listener, err := retries.Poll(ctx, DefaultTimeout,
func() (*net.Listener, *retries.Err) { func() (*net.Listener, *retries.Err) {
var lc net.ListenConfig var lc net.ListenConfig
l, err := lc.Listen(ctx, "tcp", appRedirectAddr) l, err := lc.Listen(ctx, "tcp", a.BoundAddress)
if err != nil { if err != nil {
return nil, retries.Continue(err) return nil, retries.Continue(err)
} }
@ -213,7 +227,7 @@ func (a *PersistentAuth) oauth2Config(ctx context.Context) (*oauth2.Config, erro
TokenURL: endpoints.TokenEndpoint, TokenURL: endpoints.TokenEndpoint,
AuthStyle: oauth2.AuthStyleInParams, AuthStyle: oauth2.AuthStyleInParams,
}, },
RedirectURL: fmt.Sprintf("http://%s", appRedirectAddr), RedirectURL: fmt.Sprintf("http://%s", a.BoundAddress),
Scopes: scopes, Scopes: scopes,
}, nil }, nil
} }

View File

@ -181,7 +181,7 @@ func TestChallenge(t *testing.T) {
}() }()
state := <-browserOpened state := <-browserOpened
resp, err := http.Get(fmt.Sprintf("http://%s?code=__THIS__&state=%s", appRedirectAddr, state)) resp, err := http.Get(fmt.Sprintf("http://%s?code=__THIS__&state=%s", defaultAppRedirectAddr, state))
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode) assert.Equal(t, 200, resp.StatusCode)
@ -220,7 +220,7 @@ func TestChallengeFailed(t *testing.T) {
<-browserOpened <-browserOpened
resp, err := http.Get(fmt.Sprintf( resp, err := http.Get(fmt.Sprintf(
"http://%s?error=access_denied&error_description=Policy%%20evaluation%%20failed%%20for%%20this%%20request", "http://%s?error=access_denied&error_description=Policy%%20evaluation%%20failed%%20for%%20this%%20request",
appRedirectAddr)) defaultAppRedirectAddr))
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 400, resp.StatusCode) assert.Equal(t, 400, resp.StatusCode)
@ -228,3 +228,45 @@ func TestChallengeFailed(t *testing.T) {
assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request") assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request")
}) })
} }
func TestBindPublicAddress(t *testing.T) {
p := &PersistentAuth{
Host: "abc",
AccountID: "xyz",
cache: &tokenCacheMock{
lookup: func(key string) (*oauth2.Token, error) {
assert.Equal(t, "https://abc/oidc/accounts/xyz", key)
return &oauth2.Token{
AccessToken: "bcd",
Expiry: time.Now().Add(1 * time.Minute),
}, nil
},
},
BindPublicAddress: true,
}
defer p.Close()
_, err := p.Load(context.Background())
assert.NoError(t, err)
assert.Equal(t, "[::]:8020", p.ln.Addr().String())
}
func TestBindPrivateAddressOnly(t *testing.T) {
p := &PersistentAuth{
Host: "abc",
AccountID: "xyz",
cache: &tokenCacheMock{
lookup: func(key string) (*oauth2.Token, error) {
assert.Equal(t, "https://abc/oidc/accounts/xyz", key)
return &oauth2.Token{
AccessToken: "bcd",
Expiry: time.Now().Add(1 * time.Minute),
}, nil
},
},
BindPublicAddress: false,
}
defer p.Close()
_, err := p.Load(context.Background())
assert.NoError(t, err)
assert.Equal(t, "127.0.0.1:8020", p.ln.Addr().String())
}