diff --git a/cmd/auth/auth.go b/cmd/auth/auth.go index e0c7c7c5..aac593bf 100644 --- a/cmd/auth/auth.go +++ b/cmd/auth/auth.go @@ -17,7 +17,7 @@ func New() *cobra.Command { var perisistentAuth auth.PersistentAuth cmd.PersistentFlags().StringVar(&perisistentAuth.Host, "host", perisistentAuth.Host, "Databricks Host") cmd.PersistentFlags().StringVar(&perisistentAuth.AccountID, "account-id", perisistentAuth.AccountID, "Databricks Account ID") - + cmd.PersistentFlags().BoolVar(&perisistentAuth.BindPublicAddress, "bind-public", perisistentAuth.BindPublicAddress, "Allow OAUTH redirect to bind to all local IP addresses including public addresses (NOTE: this is less secure)") cmd.AddCommand(newEnvCommand()) cmd.AddCommand(newLoginCommand(&perisistentAuth)) cmd.AddCommand(newProfilesCommand()) diff --git a/libs/auth/oauth.go b/libs/auth/oauth.go index 4ce0d4de..9b9e6218 100644 --- a/libs/auth/oauth.go +++ b/libs/auth/oauth.go @@ -24,8 +24,9 @@ const ( // these values are predefined by Databricks as a public client // and is specific to this application only. Using these values // for other applications is not allowed. - appClientID = "databricks-cli" - appRedirectAddr = "localhost:8020" + appClientID = "databricks-cli" + appRedirectPort = ":8020" + defaultAppRedirectAddr = "localhost" + appRedirectPort // maximum amount of time to acquire listener on appRedirectAddr DefaultTimeout = 45 * time.Second @@ -41,10 +42,12 @@ type PersistentAuth struct { Host string AccountID string - http *httpclient.ApiClient - cache tokenCache - ln net.Listener - browser func(string) error + http *httpclient.ApiClient + cache tokenCache + ln net.Listener + browser func(string) error + BindPublicAddress bool + BoundAddress string } type tokenCache interface { @@ -146,13 +149,24 @@ func (a *PersistentAuth) init(ctx context.Context) error { if a.browser == nil { a.browser = browser.OpenURL } + + // For various use cases need to bind to the port rather than an address, otherwise + // we only bind to a single IP which may or may not be correct. This is controlled + // by the BindPublicAddress flag. By default we will just used the defaultAppRedirectAddr + // which is localhost. See: https://pkg.go.dev/net#ListenIP for issues with this. + if a.BindPublicAddress { + a.BoundAddress = appRedirectPort + } else { + a.BoundAddress = defaultAppRedirectAddr + } + // try acquire listener, which we also use as a machine-local // exclusive lock to prevent token cache corruption in the scope // of developer machine, where this command runs. listener, err := retries.Poll(ctx, DefaultTimeout, func() (*net.Listener, *retries.Err) { var lc net.ListenConfig - l, err := lc.Listen(ctx, "tcp", appRedirectAddr) + l, err := lc.Listen(ctx, "tcp", a.BoundAddress) if err != nil { return nil, retries.Continue(err) } @@ -213,7 +227,7 @@ func (a *PersistentAuth) oauth2Config(ctx context.Context) (*oauth2.Config, erro TokenURL: endpoints.TokenEndpoint, AuthStyle: oauth2.AuthStyleInParams, }, - RedirectURL: fmt.Sprintf("http://%s", appRedirectAddr), + RedirectURL: fmt.Sprintf("http://%s", a.BoundAddress), Scopes: scopes, }, nil } diff --git a/libs/auth/oauth_test.go b/libs/auth/oauth_test.go index ea6a8061..f30374b4 100644 --- a/libs/auth/oauth_test.go +++ b/libs/auth/oauth_test.go @@ -181,7 +181,7 @@ func TestChallenge(t *testing.T) { }() state := <-browserOpened - resp, err := http.Get(fmt.Sprintf("http://%s?code=__THIS__&state=%s", appRedirectAddr, state)) + resp, err := http.Get(fmt.Sprintf("http://%s?code=__THIS__&state=%s", defaultAppRedirectAddr, state)) assert.NoError(t, err) assert.Equal(t, 200, resp.StatusCode) @@ -220,7 +220,7 @@ func TestChallengeFailed(t *testing.T) { <-browserOpened resp, err := http.Get(fmt.Sprintf( "http://%s?error=access_denied&error_description=Policy%%20evaluation%%20failed%%20for%%20this%%20request", - appRedirectAddr)) + defaultAppRedirectAddr)) assert.NoError(t, err) assert.Equal(t, 400, resp.StatusCode) @@ -228,3 +228,45 @@ func TestChallengeFailed(t *testing.T) { assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request") }) } + +func TestBindPublicAddress(t *testing.T) { + p := &PersistentAuth{ + Host: "abc", + AccountID: "xyz", + cache: &tokenCacheMock{ + lookup: func(key string) (*oauth2.Token, error) { + assert.Equal(t, "https://abc/oidc/accounts/xyz", key) + return &oauth2.Token{ + AccessToken: "bcd", + Expiry: time.Now().Add(1 * time.Minute), + }, nil + }, + }, + BindPublicAddress: true, + } + defer p.Close() + _, err := p.Load(context.Background()) + assert.NoError(t, err) + assert.Equal(t, "[::]:8020", p.ln.Addr().String()) +} + +func TestBindPrivateAddressOnly(t *testing.T) { + p := &PersistentAuth{ + Host: "abc", + AccountID: "xyz", + cache: &tokenCacheMock{ + lookup: func(key string) (*oauth2.Token, error) { + assert.Equal(t, "https://abc/oidc/accounts/xyz", key) + return &oauth2.Token{ + AccessToken: "bcd", + Expiry: time.Now().Add(1 * time.Minute), + }, nil + }, + }, + BindPublicAddress: false, + } + defer p.Close() + _, err := p.Load(context.Background()) + assert.NoError(t, err) + assert.Equal(t, "127.0.0.1:8020", p.ln.Addr().String()) +}