diff --git a/cmd/auth/describe.go b/cmd/auth/describe.go index faaf64f8f..74a691787 100644 --- a/cmd/auth/describe.go +++ b/cmd/auth/describe.go @@ -7,6 +7,7 @@ import ( "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/command" "github.com/databricks/cli/libs/flags" "github.com/databricks/databricks-sdk-go/config" "github.com/spf13/cobra" @@ -57,7 +58,7 @@ func newDescribeCommand() *cobra.Command { var err error status, err = getAuthStatus(cmd, args, showSensitive, func(cmd *cobra.Command, args []string) (*config.Config, bool, error) { isAccount, err := root.MustAnyClient(cmd, args) - return root.ConfigUsed(cmd.Context()), isAccount, err + return command.ConfigUsed(cmd.Context()), isAccount, err }) if err != nil { return err diff --git a/cmd/root/auth.go b/cmd/root/auth.go index e2dac68cc..21b8c8a96 100644 --- a/cmd/root/auth.go +++ b/cmd/root/auth.go @@ -7,6 +7,7 @@ import ( "net/http" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/command" "github.com/databricks/cli/libs/databrickscfg/profile" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/config" @@ -18,7 +19,6 @@ import ( var ( workspaceClient int accountClient int - configUsed int ) type ErrNoWorkspaceProfiles struct { @@ -119,7 +119,7 @@ func MustAccountClient(cmd *cobra.Command, args []string) error { } ctx := cmd.Context() - ctx = context.WithValue(ctx, &configUsed, cfg) + ctx = command.SetConfigUsed(ctx, cfg) cmd.SetContext(ctx) profiler := profile.GetProfiler(ctx) @@ -202,7 +202,7 @@ func MustWorkspaceClient(cmd *cobra.Command, args []string) error { } ctx := cmd.Context() - ctx = context.WithValue(ctx, &configUsed, cfg) + ctx = command.SetConfigUsed(ctx, cfg) cmd.SetContext(ctx) // Try to load a bundle configuration if we're allowed to by the caller (see `./auth_options.go`). @@ -213,7 +213,7 @@ func MustWorkspaceClient(cmd *cobra.Command, args []string) error { } if b != nil { - ctx = context.WithValue(ctx, &configUsed, b.Config.Workspace.Config()) + ctx = command.SetConfigUsed(ctx, b.Config.Workspace.Config()) cmd.SetContext(ctx) client, err := b.WorkspaceClientE() if err != nil { @@ -336,11 +336,3 @@ func AccountClient(ctx context.Context) *databricks.AccountClient { } return a } - -func ConfigUsed(ctx context.Context) *config.Config { - cfg, ok := ctx.Value(&configUsed).(*config.Config) - if !ok { - panic("cannot get *config.Config. Please report it as a bug") - } - return cfg -} diff --git a/cmd/root/bundle.go b/cmd/root/bundle.go index b40803707..d86f9a673 100644 --- a/cmd/root/bundle.go +++ b/cmd/root/bundle.go @@ -6,6 +6,7 @@ import ( "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/env" "github.com/databricks/cli/bundle/phases" + "github.com/databricks/cli/libs/command" "github.com/databricks/cli/libs/diag" envlib "github.com/databricks/cli/libs/env" "github.com/spf13/cobra" @@ -102,7 +103,7 @@ func configureBundle(cmd *cobra.Command, b *bundle.Bundle) (*bundle.Bundle, diag if err != nil { return b, diags.Extend(diag.FromErr(err)) } - ctx = context.WithValue(ctx, &configUsed, client.Config) + ctx = command.SetConfigUsed(ctx, client.Config) cmd.SetContext(ctx) return b, diags diff --git a/cmd/root/bundle_test.go b/cmd/root/bundle_test.go index 3517b02e4..5871b0ae9 100644 --- a/cmd/root/bundle_test.go +++ b/cmd/root/bundle_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/databricks/cli/internal/testutil" + "github.com/databricks/cli/libs/command" "github.com/spf13/cobra" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -78,7 +79,7 @@ func TestBundleConfigureDefault(t *testing.T) { err := setupWithHost(t, cmd, "https://x.com") require.NoError(t, err) - assert.Equal(t, "https://x.com", ConfigUsed(cmd.Context()).Host) + assert.Equal(t, "https://x.com", command.ConfigUsed(cmd.Context()).Host) } func TestBundleConfigureWithMultipleMatches(t *testing.T) { @@ -120,8 +121,8 @@ func TestBundleConfigureWithCorrectProfile(t *testing.T) { err = setupWithHost(t, cmd, "https://a.com") require.NoError(t, err) - assert.Equal(t, "https://a.com", ConfigUsed(cmd.Context()).Host) - assert.Equal(t, "PROFILE-1", ConfigUsed(cmd.Context()).Profile) + assert.Equal(t, "https://a.com", command.ConfigUsed(cmd.Context()).Host) + assert.Equal(t, "PROFILE-1", command.ConfigUsed(cmd.Context()).Profile) } func TestBundleConfigureWithMismatchedProfileEnvVariable(t *testing.T) { @@ -144,8 +145,8 @@ func TestBundleConfigureWithProfileFlagAndEnvVariable(t *testing.T) { err = setupWithHost(t, cmd, "https://a.com") require.NoError(t, err) - assert.Equal(t, "https://a.com", ConfigUsed(cmd.Context()).Host) - assert.Equal(t, "PROFILE-1", ConfigUsed(cmd.Context()).Profile) + assert.Equal(t, "https://a.com", command.ConfigUsed(cmd.Context()).Host) + assert.Equal(t, "PROFILE-1", command.ConfigUsed(cmd.Context()).Profile) } func TestBundleConfigureProfileDefault(t *testing.T) { @@ -156,9 +157,9 @@ func TestBundleConfigureProfileDefault(t *testing.T) { err := setupWithProfile(t, cmd, "PROFILE-1") require.NoError(t, err) - assert.Equal(t, "https://a.com", ConfigUsed(cmd.Context()).Host) - assert.Equal(t, "a", ConfigUsed(cmd.Context()).Token) - assert.Equal(t, "PROFILE-1", ConfigUsed(cmd.Context()).Profile) + assert.Equal(t, "https://a.com", command.ConfigUsed(cmd.Context()).Host) + assert.Equal(t, "a", command.ConfigUsed(cmd.Context()).Token) + assert.Equal(t, "PROFILE-1", command.ConfigUsed(cmd.Context()).Profile) } func TestBundleConfigureProfileFlag(t *testing.T) { @@ -171,9 +172,9 @@ func TestBundleConfigureProfileFlag(t *testing.T) { err = setupWithProfile(t, cmd, "PROFILE-1") require.NoError(t, err) - assert.Equal(t, "https://a.com", ConfigUsed(cmd.Context()).Host) - assert.Equal(t, "b", ConfigUsed(cmd.Context()).Token) - assert.Equal(t, "PROFILE-2", ConfigUsed(cmd.Context()).Profile) + assert.Equal(t, "https://a.com", command.ConfigUsed(cmd.Context()).Host) + assert.Equal(t, "b", command.ConfigUsed(cmd.Context()).Token) + assert.Equal(t, "PROFILE-2", command.ConfigUsed(cmd.Context()).Profile) } func TestBundleConfigureProfileEnvVariable(t *testing.T) { @@ -185,9 +186,9 @@ func TestBundleConfigureProfileEnvVariable(t *testing.T) { err := setupWithProfile(t, cmd, "PROFILE-1") require.NoError(t, err) - assert.Equal(t, "https://a.com", ConfigUsed(cmd.Context()).Host) - assert.Equal(t, "b", ConfigUsed(cmd.Context()).Token) - assert.Equal(t, "PROFILE-2", ConfigUsed(cmd.Context()).Profile) + assert.Equal(t, "https://a.com", command.ConfigUsed(cmd.Context()).Host) + assert.Equal(t, "b", command.ConfigUsed(cmd.Context()).Token) + assert.Equal(t, "PROFILE-2", command.ConfigUsed(cmd.Context()).Profile) } func TestBundleConfigureProfileFlagAndEnvVariable(t *testing.T) { @@ -201,9 +202,9 @@ func TestBundleConfigureProfileFlagAndEnvVariable(t *testing.T) { err = setupWithProfile(t, cmd, "PROFILE-1") require.NoError(t, err) - assert.Equal(t, "https://a.com", ConfigUsed(cmd.Context()).Host) - assert.Equal(t, "b", ConfigUsed(cmd.Context()).Token) - assert.Equal(t, "PROFILE-2", ConfigUsed(cmd.Context()).Profile) + assert.Equal(t, "https://a.com", command.ConfigUsed(cmd.Context()).Host) + assert.Equal(t, "b", command.ConfigUsed(cmd.Context()).Token) + assert.Equal(t, "PROFILE-2", command.ConfigUsed(cmd.Context()).Profile) } func TestTargetFlagFull(t *testing.T) { diff --git a/libs/command/config_used.go b/libs/command/config_used.go new file mode 100644 index 000000000..e507343bd --- /dev/null +++ b/libs/command/config_used.go @@ -0,0 +1,19 @@ +package command + +import ( + "context" + + "github.com/databricks/databricks-sdk-go/config" +) + +func SetConfigUsed(ctx context.Context, cfg *config.Config) context.Context { + return context.WithValue(ctx, configUsedKey, cfg) +} + +func ConfigUsed(ctx context.Context) *config.Config { + cfg, ok := ctx.Value(configUsedKey).(*config.Config) + if !ok { + panic("cannot get *config.Config. Please report it as a bug") + } + return cfg +} diff --git a/libs/command/context.go b/libs/command/context.go new file mode 100644 index 000000000..8f84f01e6 --- /dev/null +++ b/libs/command/context.go @@ -0,0 +1,13 @@ +package command + +// key is a package-local type to use for context keys. +// +// Using an unexported type for context keys prevents key collisions across +// packages since external packages cannot create values of this type. +type key int + +const ( + // configUsedKey is the context key for the auth configuration used to run the + // command. + configUsedKey = key(2) +)