diff --git a/bundle/config/workspace.go b/bundle/config/workspace.go index fa5e0813..d6c83b77 100644 --- a/bundle/config/workspace.go +++ b/bundle/config/workspace.go @@ -3,7 +3,9 @@ package config import ( "os" + "github.com/databricks/bricks/libs/databrickscfg" "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/config" "github.com/databricks/databricks-sdk-go/service/scim" ) @@ -68,7 +70,7 @@ type Workspace struct { } func (w *Workspace) Client() (*databricks.WorkspaceClient, error) { - config := databricks.Config{ + cfg := databricks.Config{ // Generic Host: w.Host, Profile: w.Profile, @@ -85,7 +87,19 @@ func (w *Workspace) Client() (*databricks.WorkspaceClient, error) { AzureLoginAppID: w.AzureLoginAppID, } - return databricks.NewWorkspaceClient(&config) + // If only the host is configured, we try and unambiguously match it to + // a profile in the user's databrickscfg file. Override the default loaders. + cfg.Loaders = []config.Loader{ + // Defaults. + config.ConfigAttributes, + config.KnownConfigLoader{}, + + // Our loader that resolves a profile from the host alone. + // This only kicks in if the above loaders don't configure auth. + databrickscfg.ResolveProfileFromHost, + } + + return databricks.NewWorkspaceClient(&cfg) } func init() { diff --git a/libs/databrickscfg/file.go b/libs/databrickscfg/file.go new file mode 100644 index 00000000..4595c7a6 --- /dev/null +++ b/libs/databrickscfg/file.go @@ -0,0 +1,46 @@ +package databrickscfg + +import ( + "fmt" + "os" + "strings" + + "gopkg.in/ini.v1" +) + +// File represents the contents of a databrickscfg file. +type File struct { + *ini.File + + path string +} + +// Path returns the path of the loaded databrickscfg file. +func (f *File) Path() string { + return f.path +} + +// LoadFile loads the databrickscfg file at the specified path. +// The function loads ~/.databrickscfg if the specified path is an empty string. +// The function expands ~ to the user's home directory. +func LoadFile(path string) (*File, error) { + if path == "" { + path = "~/.databrickscfg" + } + + // Expand ~ to home directory. + if strings.HasPrefix(path, "~") { + homedir, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("cannot find homedir: %w", err) + } + path = fmt.Sprintf("%s%s", homedir, path[1:]) + } + + iniFile, err := ini.Load(path) + if err != nil { + return nil, err + } + + return &File{iniFile, path}, err +} diff --git a/libs/databrickscfg/host.go b/libs/databrickscfg/host.go new file mode 100644 index 00000000..dc9f503c --- /dev/null +++ b/libs/databrickscfg/host.go @@ -0,0 +1,22 @@ +package databrickscfg + +import "net/url" + +// normalizeHost returns the string representation of only +// the scheme and host part of the specified host. +func normalizeHost(host string) string { + u, err := url.Parse(host) + if err != nil { + return host + } + if u.Scheme == "" || u.Host == "" { + return host + } + + normalized := &url.URL{ + Scheme: u.Scheme, + Host: u.Host, + } + + return normalized.String() +} diff --git a/libs/databrickscfg/host_test.go b/libs/databrickscfg/host_test.go new file mode 100644 index 00000000..0117aa61 --- /dev/null +++ b/libs/databrickscfg/host_test.go @@ -0,0 +1,26 @@ +package databrickscfg + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNormalizeHost(t *testing.T) { + assert.Equal(t, "invalid", normalizeHost("invalid")) + + // With port. + assert.Equal(t, "http://foo:123", normalizeHost("http://foo:123")) + + // With trailing slash. + assert.Equal(t, "http://foo", normalizeHost("http://foo/")) + + // With path. + assert.Equal(t, "http://foo", normalizeHost("http://foo/bar")) + + // With query string. + assert.Equal(t, "http://foo", normalizeHost("http://foo?bar")) + + // With anchor. + assert.Equal(t, "http://foo", normalizeHost("http://foo#bar")) +} diff --git a/libs/databrickscfg/loader.go b/libs/databrickscfg/loader.go new file mode 100644 index 00000000..bd281914 --- /dev/null +++ b/libs/databrickscfg/loader.go @@ -0,0 +1,99 @@ +package databrickscfg + +import ( + "context" + "fmt" + "os" + "strings" + + "github.com/databricks/bricks/libs/log" + "github.com/databricks/databricks-sdk-go/config" + "gopkg.in/ini.v1" +) + +var ResolveProfileFromHost = profileFromHostLoader{} + +type profileFromHostLoader struct{} + +func (l profileFromHostLoader) Name() string { + return "resolve-profile-from-host" +} + +func (l profileFromHostLoader) Configure(cfg *config.Config) error { + // Skip an attempt to resolve a profile from the host if any authentication + // is already configured (either directly, through environment variables, or + // if a profile was specified). + if cfg.Host == "" || l.isAnyAuthConfigured(cfg) { + return nil + } + + configFile, err := LoadFile(cfg.ConfigFile) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("cannot parse config file: %w", err) + } + + // Normalized version of the configured host. + host := normalizeHost(cfg.Host) + + // Look for sections in the configuration file that match the configured host. + var matching []*ini.Section + for _, section := range configFile.Sections() { + key, err := section.GetKey("host") + if err != nil { + log.Tracef(context.Background(), "section %s: %s", section.Name(), err) + continue + } + + // Ignore this section if the normalized host doesn't match. + if normalizeHost(key.Value()) != host { + continue + } + + matching = append(matching, section) + } + + // If there are no matching sections, we don't do anything. + if len(matching) == 0 { + return nil + } + + // If there are multiple matching sections, let the user know it is impossible + // to unambiguously select a profile to use. + if len(matching) > 1 { + var names []string + for _, section := range matching { + names = append(names, section.Name()) + } + + return fmt.Errorf( + "multiple profiles for host %s (%s): please set DATABRICKS_CONFIG_PROFILE to specify one", + host, + strings.Join(names, ", "), + ) + } + + match := matching[0] + log.Debugf(context.Background(), "Loading profile %s because of host match", match.Name()) + err = config.ConfigAttributes.ResolveFromStringMap(cfg, match.KeysHash()) + if err != nil { + return fmt.Errorf("%s %s profile: %w", configFile.Path(), match.Name(), err) + } + + return nil + +} + +func (l profileFromHostLoader) isAnyAuthConfigured(cfg *config.Config) bool { + for _, a := range config.ConfigAttributes { + if a.Auth == "" { + continue + } + if !a.IsZero(cfg) { + return true + } + } + return false +} diff --git a/libs/databrickscfg/loader_test.go b/libs/databrickscfg/loader_test.go new file mode 100644 index 00000000..59610858 --- /dev/null +++ b/libs/databrickscfg/loader_test.go @@ -0,0 +1,130 @@ +package databrickscfg + +import ( + "testing" + + "github.com/databricks/databricks-sdk-go/config" + "github.com/stretchr/testify/assert" +) + +func TestLoaderSkipsEmptyHost(t *testing.T) { + cfg := config.Config{ + Loaders: []config.Loader{ + ResolveProfileFromHost, + }, + Host: "", + } + + err := cfg.EnsureResolved() + assert.NoError(t, err) +} + +func TestLoaderSkipsExistingAuth(t *testing.T) { + cfg := config.Config{ + Loaders: []config.Loader{ + ResolveProfileFromHost, + }, + Host: "https://foo", + Token: "nonempty means pat auth", + } + + err := cfg.EnsureResolved() + assert.NoError(t, err) +} + +func TestLoaderSkipsNonExistingConfigFile(t *testing.T) { + cfg := config.Config{ + Loaders: []config.Loader{ + ResolveProfileFromHost, + }, + ConfigFile: "idontexist", + Host: "https://default", + } + + err := cfg.EnsureResolved() + assert.NoError(t, err) + assert.Empty(t, cfg.Token) +} + +func TestLoaderErrorsOnInvalidFile(t *testing.T) { + cfg := config.Config{ + Loaders: []config.Loader{ + ResolveProfileFromHost, + }, + ConfigFile: "testdata/badcfg", + Host: "https://default", + } + + err := cfg.EnsureResolved() + assert.ErrorContains(t, err, "unclosed section: ") +} + +func TestLoaderSkipssNoMatchingHost(t *testing.T) { + cfg := config.Config{ + Loaders: []config.Loader{ + ResolveProfileFromHost, + }, + ConfigFile: "testdata/databrickscfg", + Host: "https://noneofthehostsmatch", + } + + err := cfg.EnsureResolved() + assert.NoError(t, err) + assert.Empty(t, cfg.Token) +} + +func TestLoaderConfiguresMatchingHost(t *testing.T) { + cfg := config.Config{ + Loaders: []config.Loader{ + ResolveProfileFromHost, + }, + ConfigFile: "testdata/databrickscfg", + Host: "https://default/?foo=bar", + } + + err := cfg.EnsureResolved() + assert.NoError(t, err) + assert.Equal(t, "default", cfg.Token) +} + +func TestLoaderMatchingHost(t *testing.T) { + cfg := config.Config{ + Loaders: []config.Loader{ + ResolveProfileFromHost, + }, + ConfigFile: "testdata/databrickscfg", + Host: "https://default", + } + + err := cfg.EnsureResolved() + assert.NoError(t, err) + assert.Equal(t, "default", cfg.Token) +} + +func TestLoaderMatchingHostWithQuery(t *testing.T) { + cfg := config.Config{ + Loaders: []config.Loader{ + ResolveProfileFromHost, + }, + ConfigFile: "testdata/databrickscfg", + Host: "https://query/?foo=bar", + } + + err := cfg.EnsureResolved() + assert.NoError(t, err) + assert.Equal(t, "query", cfg.Token) +} + +func TestLoaderErrorsOnMultipleMatches(t *testing.T) { + cfg := config.Config{ + Loaders: []config.Loader{ + ResolveProfileFromHost, + }, + ConfigFile: "testdata/databrickscfg", + Host: "https://foo/bar", + } + + err := cfg.EnsureResolved() + assert.Error(t, err) + assert.ErrorContains(t, err, "multiple profiles for host https://foo (foo1, foo2): ") +} diff --git a/libs/databrickscfg/testdata/badcfg b/libs/databrickscfg/testdata/badcfg new file mode 100644 index 00000000..f749f7ec --- /dev/null +++ b/libs/databrickscfg/testdata/badcfg @@ -0,0 +1 @@ +[[[[[bad diff --git a/libs/databrickscfg/testdata/databrickscfg b/libs/databrickscfg/testdata/databrickscfg new file mode 100644 index 00000000..ad81933e --- /dev/null +++ b/libs/databrickscfg/testdata/databrickscfg @@ -0,0 +1,20 @@ +[DEFAULT] +host = https://default +token = default + +[query] +host = https://query/?o=1234 +token = query + +[nohost] +token = query + +# Duplicate entry for https://foo +[foo1] +host = https://foo +token = foo1 + +# Duplicate entry for https://foo +[foo2] +host = https://foo +token = foo2