diff --git a/bundle/config/workspace.go b/bundle/config/workspace.go index f29d7c56..16a70afb 100644 --- a/bundle/config/workspace.go +++ b/bundle/config/workspace.go @@ -79,7 +79,7 @@ func (s User) MarshalJSON() ([]byte, error) { } func (w *Workspace) Client() (*databricks.WorkspaceClient, error) { - cfg := databricks.Config{ + cfg := config.Config{ // Generic Host: w.Host, Profile: w.Profile, @@ -114,14 +114,23 @@ func (w *Workspace) Client() (*databricks.WorkspaceClient, error) { } } - if w.Profile != "" && w.Host != "" { + // Resolve the configuration. This is done by [databricks.NewWorkspaceClient] as well, but here + // we need to verify that a profile, if loaded, matches the host configured in the bundle. + err := cfg.EnsureResolved() + if err != nil { + return nil, err + } + + // Now that the configuration is resolved, we can verify that the host in the bundle configuration + // is identical to the host associated with the selected profile. + if w.Host != "" && w.Profile != "" { err := databrickscfg.ValidateConfigAndProfileHost(&cfg, w.Profile) if err != nil { return nil, err } } - return databricks.NewWorkspaceClient(&cfg) + return databricks.NewWorkspaceClient((*databricks.Config)(&cfg)) } func init() { diff --git a/bundle/config/workspace_test.go b/bundle/config/workspace_test.go new file mode 100644 index 00000000..3ef96325 --- /dev/null +++ b/bundle/config/workspace_test.go @@ -0,0 +1,144 @@ +package config + +import ( + "context" + "io/fs" + "path/filepath" + "runtime" + "testing" + + "github.com/databricks/cli/internal/testutil" + "github.com/databricks/cli/libs/databrickscfg" + "github.com/databricks/databricks-sdk-go/config" + "github.com/stretchr/testify/assert" +) + +func setupWorkspaceTest(t *testing.T) string { + testutil.CleanupEnvironment(t) + + home := t.TempDir() + t.Setenv("HOME", home) + if runtime.GOOS == "windows" { + t.Setenv("USERPROFILE", home) + } + + return home +} + +func TestWorkspaceResolveProfileFromHost(t *testing.T) { + // If only a workspace host is specified, try to find a profile that uses + // the same workspace host (unambiguously). + w := Workspace{ + Host: "https://abc.cloud.databricks.com", + } + + t.Run("no config file", func(t *testing.T) { + setupWorkspaceTest(t) + _, err := w.Client() + assert.NoError(t, err) + }) + + t.Run("default config file", func(t *testing.T) { + setupWorkspaceTest(t) + + // This works if there is a config file with a matching profile. + databrickscfg.SaveToProfile(context.Background(), &config.Config{ + Profile: "default", + Host: "https://abc.cloud.databricks.com", + Token: "123", + }) + + client, err := w.Client() + assert.NoError(t, err) + assert.Equal(t, "default", client.Config.Profile) + }) + + t.Run("custom config file", func(t *testing.T) { + home := setupWorkspaceTest(t) + + // This works if there is a config file with a matching profile. + databrickscfg.SaveToProfile(context.Background(), &config.Config{ + ConfigFile: filepath.Join(home, "customcfg"), + Profile: "custom", + Host: "https://abc.cloud.databricks.com", + Token: "123", + }) + + t.Setenv("DATABRICKS_CONFIG_FILE", filepath.Join(home, "customcfg")) + client, err := w.Client() + assert.NoError(t, err) + assert.Equal(t, "custom", client.Config.Profile) + }) +} + +func TestWorkspaceVerifyProfileForHost(t *testing.T) { + // If both a workspace host and a profile are specified, + // verify that the host configured in the profile matches + // the host configured in the bundle configuration. + w := Workspace{ + Host: "https://abc.cloud.databricks.com", + Profile: "abc", + } + + t.Run("no config file", func(t *testing.T) { + setupWorkspaceTest(t) + _, err := w.Client() + assert.ErrorIs(t, err, fs.ErrNotExist) + }) + + t.Run("default config file with match", func(t *testing.T) { + setupWorkspaceTest(t) + + // This works if there is a config file with a matching profile. + databrickscfg.SaveToProfile(context.Background(), &config.Config{ + Profile: "abc", + Host: "https://abc.cloud.databricks.com", + }) + + _, err := w.Client() + assert.NoError(t, err) + }) + + t.Run("default config file with mismatch", func(t *testing.T) { + setupWorkspaceTest(t) + + // This works if there is a config file with a matching profile. + databrickscfg.SaveToProfile(context.Background(), &config.Config{ + Profile: "abc", + Host: "https://def.cloud.databricks.com", + }) + + _, err := w.Client() + assert.ErrorContains(t, err, "config host mismatch") + }) + + t.Run("custom config file with match", func(t *testing.T) { + home := setupWorkspaceTest(t) + + // This works if there is a config file with a matching profile. + databrickscfg.SaveToProfile(context.Background(), &config.Config{ + ConfigFile: filepath.Join(home, "customcfg"), + Profile: "abc", + Host: "https://abc.cloud.databricks.com", + }) + + t.Setenv("DATABRICKS_CONFIG_FILE", filepath.Join(home, "customcfg")) + _, err := w.Client() + assert.NoError(t, err) + }) + + t.Run("custom config file with mismatch", func(t *testing.T) { + home := setupWorkspaceTest(t) + + // This works if there is a config file with a matching profile. + databrickscfg.SaveToProfile(context.Background(), &config.Config{ + ConfigFile: filepath.Join(home, "customcfg"), + Profile: "abc", + Host: "https://def.cloud.databricks.com", + }) + + t.Setenv("DATABRICKS_CONFIG_FILE", filepath.Join(home, "customcfg")) + _, err := w.Client() + assert.ErrorContains(t, err, "config host mismatch") + }) +} diff --git a/cmd/root/bundle_test.go b/cmd/root/bundle_test.go index 3f9641b7..d7bae2d1 100644 --- a/cmd/root/bundle_test.go +++ b/cmd/root/bundle_test.go @@ -83,7 +83,7 @@ func TestBundleConfigureWithNonExistentProfileFlag(t *testing.T) { cmd.Flag("profile").Value.Set("NOEXIST") b := setup(t, cmd, "https://x.com") - assert.PanicsWithError(t, "no matching config profiles found", func() { + assert.Panics(t, func() { b.WorkspaceClient() }) } diff --git a/libs/databrickscfg/loader.go b/libs/databrickscfg/loader.go index 05698eb4..a7985390 100644 --- a/libs/databrickscfg/loader.go +++ b/libs/databrickscfg/loader.go @@ -103,6 +103,7 @@ func (l profileFromHostLoader) Configure(cfg *config.Config) error { return fmt.Errorf("%s %s profile: %w", configFile.Path(), match.Name(), err) } + cfg.Profile = match.Name() return nil } diff --git a/libs/databrickscfg/loader_test.go b/libs/databrickscfg/loader_test.go index 5fa7f7dd..0677687f 100644 --- a/libs/databrickscfg/loader_test.go +++ b/libs/databrickscfg/loader_test.go @@ -59,7 +59,7 @@ func TestLoaderErrorsOnInvalidFile(t *testing.T) { assert.ErrorContains(t, err, "unclosed section: ") } -func TestLoaderSkipssNoMatchingHost(t *testing.T) { +func TestLoaderSkipsNoMatchingHost(t *testing.T) { cfg := config.Config{ Loaders: []config.Loader{ ResolveProfileFromHost, @@ -73,20 +73,6 @@ func TestLoaderSkipssNoMatchingHost(t *testing.T) { assert.Empty(t, cfg.Token) } -func TestLoaderConfiguresMatchingHost(t *testing.T) { - cfg := config.Config{ - Loaders: []config.Loader{ - ResolveProfileFromHost, - }, - ConfigFile: "testdata/databrickscfg", - Host: "https://default/?foo=bar", - } - - err := cfg.EnsureResolved() - assert.NoError(t, err) - assert.Equal(t, "default", cfg.Token) -} - func TestLoaderMatchingHost(t *testing.T) { cfg := config.Config{ Loaders: []config.Loader{ @@ -99,6 +85,7 @@ func TestLoaderMatchingHost(t *testing.T) { err := cfg.EnsureResolved() assert.NoError(t, err) assert.Equal(t, "default", cfg.Token) + assert.Equal(t, "DEFAULT", cfg.Profile) } func TestLoaderMatchingHostWithQuery(t *testing.T) { @@ -113,6 +100,7 @@ func TestLoaderMatchingHostWithQuery(t *testing.T) { err := cfg.EnsureResolved() assert.NoError(t, err) assert.Equal(t, "query", cfg.Token) + assert.Equal(t, "query", cfg.Profile) } func TestLoaderErrorsOnMultipleMatches(t *testing.T) { diff --git a/libs/databrickscfg/ops.go b/libs/databrickscfg/ops.go index c2d6e9fa..90795afd 100644 --- a/libs/databrickscfg/ops.go +++ b/libs/databrickscfg/ops.go @@ -7,7 +7,6 @@ import ( "strings" "github.com/databricks/cli/libs/log" - "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/config" "gopkg.in/ini.v1" ) @@ -130,17 +129,17 @@ func SaveToProfile(ctx context.Context, cfg *config.Config) error { return configFile.SaveTo(configFile.Path()) } -func ValidateConfigAndProfileHost(cfg *databricks.Config, profile string) error { +func ValidateConfigAndProfileHost(cfg *config.Config, profile string) error { configFile, err := config.LoadFile(cfg.ConfigFile) if err != nil { return fmt.Errorf("cannot parse config file: %w", err) } + // Normalized version of the configured host. host := normalizeHost(cfg.Host) match, err := findMatchingProfile(configFile, func(s *ini.Section) bool { return profile == s.Name() }) - if err != nil { return err }