Compare commits

...

4 Commits

Author SHA1 Message Date
Miles Yucht 37f5e80ba5
Merge 517e1b3310 into cc112961ce 2024-10-16 13:57:15 +00:00
shreyas-goenka cc112961ce
Fix `TestAccFsMkdirWhenFileExistsAtPath` in isolated Azure environments (#1833)
## Changes
This test passes on normal `azure-prod` but started to fail on
`azure-prod-is`, which is the isolated version of azure-prod. This PR
patches the test to include the error returned from the cloud setup in
`azure-prod-is`.

## Tests
The test passes now on `azure-prod-is`.
2024-10-16 12:50:17 +00:00
Miles Yucht 517e1b3310
handle edge case 2024-09-12 10:51:36 +02:00
Miles Yucht f9675ab8ea
hackathon device code flow 2024-09-12 10:49:27 +02:00
5 changed files with 113 additions and 22 deletions

View File

@ -84,10 +84,13 @@ depends on the existing profiles you have set in your configuration file
var loginTimeout time.Duration
var configureCluster bool
var deviceCode bool
cmd.Flags().DurationVar(&loginTimeout, "timeout", defaultTimeout,
"Timeout for completing login challenge in the browser")
cmd.Flags().BoolVar(&configureCluster, "configure-cluster", false,
"Prompts to configure cluster")
cmd.Flags().BoolVar(&deviceCode, "device-code", false,
"Use device code flow for authentication")
cmd.RunE = func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
@ -120,7 +123,11 @@ depends on the existing profiles you have set in your configuration file
ctx, cancel := context.WithTimeout(ctx, loginTimeout)
defer cancel()
if deviceCode {
err = persistentAuth.DeviceCode(ctx)
} else {
err = persistentAuth.Challenge(ctx)
}
if err != nil {
return err
}

View File

@ -112,8 +112,8 @@ func TestAccFsMkdirWhenFileExistsAtPath(t *testing.T) {
// assert mkdir fails
_, _, err = RequireErrorRun(t, "fs", "mkdir", path.Join(tmpDir, "hello"))
// Different cloud providers return different errors.
regex := regexp.MustCompile(`(^|: )Path is a file: .*$|(^|: )Cannot create directory .* because .* is an existing file\.$|(^|: )mkdirs\(hadoopPath: .*, permission: rwxrwxrwx\): failed$`)
// Different cloud providers or cloud configurations return different errors.
regex := regexp.MustCompile(`(^|: )Path is a file: .*$|(^|: )Cannot create directory .* because .* is an existing file\.$|(^|: )mkdirs\(hadoopPath: .*, permission: rwxrwxrwx\): failed$|(^|: )"The specified path already exists.".*$`)
assert.Regexp(t, regex, err.Error())
})

View File

@ -20,6 +20,7 @@ import (
"time"
"github.com/databricks/cli/cmd/root"
"github.com/databricks/cli/internal/acc"
"github.com/databricks/cli/libs/flags"
"github.com/databricks/cli/cmd"
@ -591,13 +592,10 @@ func setupWsfsExtensionsFiler(t *testing.T) (filer.Filer, string) {
}
func setupDbfsFiler(t *testing.T) (filer.Filer, string) {
t.Log(GetEnvOrSkipTest(t, "CLOUD_ENV"))
_, wt := acc.WorkspaceTest(t)
w, err := databricks.NewWorkspaceClient()
require.NoError(t, err)
tmpDir := TemporaryDbfsDir(t, w)
f, err := filer.NewDbfsClient(w, tmpDir)
tmpDir := TemporaryDbfsDir(t, wt.W)
f, err := filer.NewDbfsClient(wt.W, tmpDir)
require.NoError(t, err)
return f, path.Join("dbfs:/", tmpDir)

View File

@ -49,15 +49,23 @@ var ( // Databricks SDK API: `databricks OAuth is not` will be checked for prese
ErrOAuthNotSupported = errors.New("databricks OAuth is not supported for this host")
ErrNotConfigured = errors.New("databricks OAuth is not configured for this host")
ErrFetchCredentials = errors.New("cannot fetch credentials")
ErrDeviceCodeNotSupported = errors.New("device code flow is not supported for this host")
)
type PersistentAuth struct {
Host string
AccountID string
// The client used when making requests to Databricks OAuth endpoints.
http *httpclient.ApiClient
// A token cache for OAuth access & refresh tokens.
cache cache.TokenCache
// A listener used to receive the OAuth callback. Not used for device-code flow.
ln net.Listener
// A function to open a URL in the user's browser. Not used for device-code flow.
browser func(string) error
}
@ -113,11 +121,44 @@ func (a *PersistentAuth) ProfileName() string {
return split[0]
}
func (a *PersistentAuth) DeviceCode(ctx context.Context) error {
err := a.init(ctx)
if err != nil {
return fmt.Errorf("init: %w", err)
}
cfg, err := a.oauth2Config(ctx)
if err != nil {
return err
}
if cfg.Endpoint.DeviceAuthURL == "" {
return ErrDeviceCodeNotSupported
}
ctx = a.http.InContextForOAuth2(ctx)
deviceAuthResp, err := cfg.DeviceAuth(ctx)
if err != nil {
return fmt.Errorf("error initiating device code flow: %w", err)
}
fmt.Printf("To authenticate, please visit %s and enter the code %s\n", deviceAuthResp.VerificationURI, deviceAuthResp.UserCode)
token, err := cfg.DeviceAccessToken(ctx, deviceAuthResp)
if err != nil {
return fmt.Errorf("error retrieving token: %w", err)
}
err = a.cache.Store(a.key(), token)
if err != nil {
return fmt.Errorf("store: %w", err)
}
return nil
}
func (a *PersistentAuth) Challenge(ctx context.Context) error {
err := a.init(ctx)
if err != nil {
return fmt.Errorf("init: %w", err)
}
err = a.initU2M(ctx)
if err != nil {
return fmt.Errorf("init: %w", err)
}
cfg, err := a.oauth2Config(ctx)
if err != nil {
return err
@ -143,6 +184,8 @@ func (a *PersistentAuth) Challenge(ctx context.Context) error {
return nil
}
// init validates that the host and account id are set and initializes the http client and token cache.
// It should be called before any other method on PersistentAuth.
func (a *PersistentAuth) init(ctx context.Context) error {
if a.Host == "" && a.AccountID == "" {
return ErrFetchCredentials
@ -153,6 +196,11 @@ func (a *PersistentAuth) init(ctx context.Context) error {
if a.cache == nil {
a.cache = cache.GetTokenCache(ctx)
}
return nil
}
// initU2M initializes the listener for the user-to-machine flow. It does not need to be called for device-code flow.
func (a *PersistentAuth) initU2M(ctx context.Context) error {
if a.browser == nil {
a.browser = browser.OpenURL
}
@ -188,6 +236,7 @@ func (a *PersistentAuth) oidcEndpoints(ctx context.Context) (*oauthAuthorization
return &oauthAuthorizationServer{
AuthorizationEndpoint: fmt.Sprintf("%s/v1/authorize", prefix),
TokenEndpoint: fmt.Sprintf("%s/v1/token", prefix),
DeviceAuthorizationEndpoint: fmt.Sprintf("%s/v1/device_authorization", prefix),
}, nil
}
var oauthEndpoints oauthAuthorizationServer
@ -221,6 +270,7 @@ func (a *PersistentAuth) oauth2Config(ctx context.Context) (*oauth2.Config, erro
Endpoint: oauth2.Endpoint{
AuthURL: endpoints.AuthorizationEndpoint,
TokenURL: endpoints.TokenEndpoint,
DeviceAuthURL: endpoints.DeviceAuthorizationEndpoint,
AuthStyle: oauth2.AuthStyleInParams,
},
RedirectURL: fmt.Sprintf("http://%s", appRedirectAddr),
@ -262,4 +312,5 @@ func (a *PersistentAuth) randomString(size int) string {
type oauthAuthorizationServer struct {
AuthorizationEndpoint string `json:"authorization_endpoint"` // ../v1/authorize
TokenEndpoint string `json:"token_endpoint"` // ../v1/token
DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint"` // ../v1/device_authorization
}

View File

@ -228,3 +228,38 @@ func TestChallengeFailed(t *testing.T) {
assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request")
})
}
func TestDeviceCode_Account(t *testing.T) {
qa.HTTPFixtures{
{
Method: "POST",
Resource: "/oidc/accounts/xyz/v1/device_authorization",
Response: `{"device_code":"abc","user_code":"def","verification_uri":"ghi"}`,
},
{
Method: "POST",
Resource: "/oidc/accounts/xyz/v1/token",
Response: `access_token=jkl&refresh_token=mnop`,
},
}.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) {
ctx = useInsecureOAuthHttpClientForTests(ctx)
tokenStored := false
p := &PersistentAuth{
Host: c.Config.Host,
AccountID: "xyz",
cache: &tokenCacheMock{
store: func(key string, tok *oauth2.Token) error {
assert.Equal(t, fmt.Sprintf("%s/oidc/accounts/xyz", c.Config.Host), key)
assert.Equal(t, "mnop", tok.RefreshToken)
tokenStored = true
return nil
},
},
}
defer p.Close()
err := p.DeviceCode(ctx)
assert.NoError(t, err)
assert.True(t, tokenStored)
})
}