From ec892aa11c7c5f9a98220d95fa448a4d1433cc23 Mon Sep 17 00:00:00 2001 From: Pieter Noordhuis Date: Wed, 26 Jul 2023 13:17:09 +0200 Subject: [PATCH] Remove dependency on global state for the root command (#606) ## Changes This change is another step towards a CLI without globals. Also see #595. The flags for the root command are now encapsulated in struct types. ## Tests Unit tests pass. --- cmd/bundle/run.go | 4 +-- cmd/root/auth.go | 6 ++-- cmd/root/bundle.go | 6 ++-- cmd/root/bundle_test.go | 51 ++++++++++++++++++------------ cmd/root/io.go | 29 +++++++++++------ cmd/root/logger.go | 54 +++++++++++++++++++------------- cmd/root/progress_logger.go | 37 +++++++++++++++------- cmd/root/progress_logger_test.go | 36 +++++++++++++++++---- cmd/root/root.go | 52 +++++++++++++++++------------- 9 files changed, 177 insertions(+), 98 deletions(-) diff --git a/cmd/bundle/run.go b/cmd/bundle/run.go index 9ca8fe45..439e3522 100644 --- a/cmd/bundle/run.go +++ b/cmd/bundle/run.go @@ -47,7 +47,7 @@ var runCmd = &cobra.Command{ return err } if output != nil { - switch root.OutputType() { + switch root.OutputType(cmd) { case flags.OutputText: resultString, err := output.String() if err != nil { @@ -61,7 +61,7 @@ var runCmd = &cobra.Command{ } cmd.OutOrStdout().Write(b) default: - return fmt.Errorf("unknown output type %s", root.OutputType()) + return fmt.Errorf("unknown output type %s", root.OutputType(cmd)) } } return nil diff --git a/cmd/root/auth.go b/cmd/root/auth.go index ae7f7396..c13f7463 100644 --- a/cmd/root/auth.go +++ b/cmd/root/auth.go @@ -21,9 +21,9 @@ var workspaceClient int var accountClient int var currentUser int -func init() { - RootCmd.PersistentFlags().StringP("profile", "p", "", "~/.databrickscfg profile") - RootCmd.RegisterFlagCompletionFunc("profile", databrickscfg.ProfileCompletion) +func initProfileFlag(cmd *cobra.Command) { + cmd.PersistentFlags().StringP("profile", "p", "", "~/.databrickscfg profile") + cmd.RegisterFlagCompletionFunc("profile", databrickscfg.ProfileCompletion) } func MustAccountClient(cmd *cobra.Command, args []string) error { diff --git a/cmd/root/bundle.go b/cmd/root/bundle.go index 8eab7c2c..8a3b5977 100644 --- a/cmd/root/bundle.go +++ b/cmd/root/bundle.go @@ -118,8 +118,8 @@ func environmentCompletion(cmd *cobra.Command, args []string, toComplete string) return maps.Keys(b.Config.Environments), cobra.ShellCompDirectiveDefault } -func init() { +func initEnvironmentFlag(cmd *cobra.Command) { // To operate in the context of a bundle, all commands must take an "environment" parameter. - RootCmd.PersistentFlags().StringP("environment", "e", "", "bundle environment to use (if applicable)") - RootCmd.RegisterFlagCompletionFunc("environment", environmentCompletion) + cmd.PersistentFlags().StringP("environment", "e", "", "bundle environment to use (if applicable)") + cmd.RegisterFlagCompletionFunc("environment", environmentCompletion) } diff --git a/cmd/root/bundle_test.go b/cmd/root/bundle_test.go index 8dc771bd..4b44e019 100644 --- a/cmd/root/bundle_test.go +++ b/cmd/root/bundle_test.go @@ -9,6 +9,7 @@ import ( "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config" + "github.com/spf13/cobra" "github.com/stretchr/testify/assert" ) @@ -27,15 +28,18 @@ func setupDatabricksCfg(t *testing.T) { t.Setenv(homeEnvVar, tempHomeDir) } -func setup(t *testing.T, host string) *bundle.Bundle { +func emptyCommand(t *testing.T) *cobra.Command { + ctx := context.Background() + cmd := &cobra.Command{} + cmd.SetContext(ctx) + initProfileFlag(cmd) + return cmd +} + +func setup(t *testing.T, cmd *cobra.Command, host string) *bundle.Bundle { setupDatabricksCfg(t) - ctx := context.Background() - RootCmd.SetContext(ctx) - _, err := initializeLogger(ctx) - assert.NoError(t, err) - - err = configureBundle(RootCmd, []string{"validate"}, func() (*bundle.Bundle, error) { + err := configureBundle(cmd, []string{"validate"}, func() (*bundle.Bundle, error) { return &bundle.Bundle{ Config: config.Root{ Bundle: config.Bundle{ @@ -48,46 +52,50 @@ func setup(t *testing.T, host string) *bundle.Bundle { }, nil }) assert.NoError(t, err) - - return bundle.Get(RootCmd.Context()) + return bundle.Get(cmd.Context()) } func TestBundleConfigureDefault(t *testing.T) { - b := setup(t, "https://x.com") + cmd := emptyCommand(t) + b := setup(t, cmd, "https://x.com") assert.NotPanics(t, func() { b.WorkspaceClient() }) } func TestBundleConfigureWithMultipleMatches(t *testing.T) { - b := setup(t, "https://a.com") + cmd := emptyCommand(t) + b := setup(t, cmd, "https://a.com") assert.Panics(t, func() { b.WorkspaceClient() }) } func TestBundleConfigureWithNonExistentProfileFlag(t *testing.T) { - RootCmd.Flag("profile").Value.Set("NOEXIST") + cmd := emptyCommand(t) + cmd.Flag("profile").Value.Set("NOEXIST") - b := setup(t, "https://x.com") + b := setup(t, cmd, "https://x.com") assert.PanicsWithError(t, "no matching config profiles found", func() { b.WorkspaceClient() }) } func TestBundleConfigureWithMismatchedProfile(t *testing.T) { - RootCmd.Flag("profile").Value.Set("PROFILE-1") + cmd := emptyCommand(t) + cmd.Flag("profile").Value.Set("PROFILE-1") - b := setup(t, "https://x.com") + b := setup(t, cmd, "https://x.com") assert.PanicsWithError(t, "config host mismatch: profile uses host https://a.com, but CLI configured to use https://x.com", func() { b.WorkspaceClient() }) } func TestBundleConfigureWithCorrectProfile(t *testing.T) { - RootCmd.Flag("profile").Value.Set("PROFILE-1") + cmd := emptyCommand(t) + cmd.Flag("profile").Value.Set("PROFILE-1") - b := setup(t, "https://a.com") + b := setup(t, cmd, "https://a.com") assert.NotPanics(t, func() { b.WorkspaceClient() }) @@ -99,7 +107,8 @@ func TestBundleConfigureWithMismatchedProfileEnvVariable(t *testing.T) { t.Setenv("DATABRICKS_CONFIG_PROFILE", "") }) - b := setup(t, "https://x.com") + cmd := emptyCommand(t) + b := setup(t, cmd, "https://x.com") assert.PanicsWithError(t, "config host mismatch: profile uses host https://a.com, but CLI configured to use https://x.com", func() { b.WorkspaceClient() }) @@ -110,9 +119,11 @@ func TestBundleConfigureWithProfileFlagAndEnvVariable(t *testing.T) { t.Cleanup(func() { t.Setenv("DATABRICKS_CONFIG_PROFILE", "") }) - RootCmd.Flag("profile").Value.Set("PROFILE-1") - b := setup(t, "https://a.com") + cmd := emptyCommand(t) + cmd.Flag("profile").Value.Set("PROFILE-1") + + b := setup(t, cmd, "https://a.com") assert.NotPanics(t, func() { b.WorkspaceClient() }) diff --git a/cmd/root/io.go b/cmd/root/io.go index 93830c80..380c01b1 100644 --- a/cmd/root/io.go +++ b/cmd/root/io.go @@ -10,32 +10,43 @@ import ( const envOutputFormat = "DATABRICKS_OUTPUT_FORMAT" -var outputType flags.Output = flags.OutputText +type outputFlag struct { + output flags.Output +} + +func initOutputFlag(cmd *cobra.Command) *outputFlag { + f := outputFlag{ + output: flags.OutputText, + } -func init() { // Configure defaults from environment, if applicable. // If the provided value is invalid it is ignored. if v, ok := os.LookupEnv(envOutputFormat); ok { - outputType.Set(v) + f.output.Set(v) } - RootCmd.PersistentFlags().VarP(&outputType, "output", "o", "output type: text or json") + cmd.PersistentFlags().VarP(&f.output, "output", "o", "output type: text or json") + return &f } -func OutputType() flags.Output { - return outputType +func OutputType(cmd *cobra.Command) flags.Output { + f, ok := cmd.Flag("output").Value.(*flags.Output) + if !ok { + panic("output flag not defined") + } + + return *f } -func initializeIO(cmd *cobra.Command) error { +func (f *outputFlag) initializeIO(cmd *cobra.Command) error { var template string if cmd.Annotations != nil { // rely on zeroval being an empty string template = cmd.Annotations["template"] } - cmdIO := cmdio.NewIO(outputType, cmd.InOrStdin(), cmd.OutOrStdout(), cmd.ErrOrStderr(), template) + cmdIO := cmdio.NewIO(f.output, cmd.InOrStdin(), cmd.OutOrStdout(), cmd.ErrOrStderr(), template) ctx := cmdio.InContext(cmd.Context(), cmdIO) cmd.SetContext(ctx) - return nil } diff --git a/cmd/root/logger.go b/cmd/root/logger.go index 89d70760..87f69550 100644 --- a/cmd/root/logger.go +++ b/cmd/root/logger.go @@ -10,6 +10,7 @@ import ( "github.com/databricks/cli/libs/flags" "github.com/databricks/cli/libs/log" "github.com/fatih/color" + "github.com/spf13/cobra" "golang.org/x/exp/slog" ) @@ -66,12 +67,18 @@ func (l *friendlyHandler) Handle(ctx context.Context, rec slog.Record) error { return err } -func makeLogHandler(opts slog.HandlerOptions) (slog.Handler, error) { - switch logOutput { +type logFlags struct { + file flags.LogFileFlag + level flags.LogLevelFlag + output flags.Output +} + +func (f *logFlags) makeLogHandler(opts slog.HandlerOptions) (slog.Handler, error) { + switch f.output { case flags.OutputJSON: - return opts.NewJSONHandler(logFile.Writer()), nil + return opts.NewJSONHandler(f.file.Writer()), nil case flags.OutputText: - w := logFile.Writer() + w := f.file.Writer() if cmdio.IsTTY(w) { return &friendlyHandler{ Handler: opts.NewTextHandler(w), @@ -81,13 +88,13 @@ func makeLogHandler(opts slog.HandlerOptions) (slog.Handler, error) { return opts.NewTextHandler(w), nil default: - return nil, fmt.Errorf("invalid log output mode: %s", logOutput) + return nil, fmt.Errorf("invalid log output mode: %s", f.output) } } -func initializeLogger(ctx context.Context) (context.Context, error) { +func (f *logFlags) initializeContext(ctx context.Context) (context.Context, error) { opts := slog.HandlerOptions{} - opts.Level = logLevel.Level() + opts.Level = f.level.Level() opts.AddSource = true opts.ReplaceAttr = log.ReplaceAttrFunctions{ log.ReplaceLevelAttr, @@ -95,12 +102,12 @@ func initializeLogger(ctx context.Context) (context.Context, error) { }.ReplaceAttr // Open the underlying log file if the user configured an actual file to log to. - err := logFile.Open() + err := f.file.Open() if err != nil { return nil, err } - handler, err := makeLogHandler(opts) + handler, err := f.makeLogHandler(opts) if err != nil { return nil, err } @@ -109,27 +116,30 @@ func initializeLogger(ctx context.Context) (context.Context, error) { return log.NewContext(ctx, slog.Default()), nil } -var logFile = flags.NewLogFileFlag() -var logLevel = flags.NewLogLevelFlag() -var logOutput = flags.OutputText +func initLogFlags(cmd *cobra.Command) *logFlags { + f := logFlags{ + file: flags.NewLogFileFlag(), + level: flags.NewLogLevelFlag(), + output: flags.OutputText, + } -func init() { // Configure defaults from environment, if applicable. // If the provided value is invalid it is ignored. if v, ok := os.LookupEnv(envLogFile); ok { - logFile.Set(v) + f.file.Set(v) } if v, ok := os.LookupEnv(envLogLevel); ok { - logLevel.Set(v) + f.level.Set(v) } if v, ok := os.LookupEnv(envLogFormat); ok { - logOutput.Set(v) + f.output.Set(v) } - RootCmd.PersistentFlags().Var(&logFile, "log-file", "file to write logs to") - RootCmd.PersistentFlags().Var(&logLevel, "log-level", "log level") - RootCmd.PersistentFlags().Var(&logOutput, "log-format", "log output format (text or json)") - RootCmd.RegisterFlagCompletionFunc("log-file", logFile.Complete) - RootCmd.RegisterFlagCompletionFunc("log-level", logLevel.Complete) - RootCmd.RegisterFlagCompletionFunc("log-format", logOutput.Complete) + cmd.PersistentFlags().Var(&f.file, "log-file", "file to write logs to") + cmd.PersistentFlags().Var(&f.level, "log-level", "log level") + cmd.PersistentFlags().Var(&f.output, "log-format", "log output format (text or json)") + cmd.RegisterFlagCompletionFunc("log-file", f.file.Complete) + cmd.RegisterFlagCompletionFunc("log-level", f.level.Complete) + cmd.RegisterFlagCompletionFunc("log-format", f.output.Complete) + return &f } diff --git a/cmd/root/progress_logger.go b/cmd/root/progress_logger.go index fbd90ebb..bdf52558 100644 --- a/cmd/root/progress_logger.go +++ b/cmd/root/progress_logger.go @@ -7,42 +7,55 @@ import ( "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/flags" + "github.com/spf13/cobra" "golang.org/x/term" ) const envProgressFormat = "DATABRICKS_CLI_PROGRESS_FORMAT" -func resolveModeDefault(format flags.ProgressLogFormat) flags.ProgressLogFormat { - if (logLevel.String() == "disabled" || logFile.String() != "stderr") && +type progressLoggerFlag struct { + flags.ProgressLogFormat + + log *logFlags +} + +func (f *progressLoggerFlag) resolveModeDefault(format flags.ProgressLogFormat) flags.ProgressLogFormat { + if (f.log.level.String() == "disabled" || f.log.file.String() != "stderr") && term.IsTerminal(int(os.Stderr.Fd())) { return flags.ModeInplace } return flags.ModeAppend } -func initializeProgressLogger(ctx context.Context) (context.Context, error) { - if logLevel.String() != "disabled" && logFile.String() == "stderr" && - progressFormat == flags.ModeInplace { +func (f *progressLoggerFlag) initializeContext(ctx context.Context) (context.Context, error) { + if f.log.level.String() != "disabled" && f.log.file.String() == "stderr" && + f.ProgressLogFormat == flags.ModeInplace { return nil, fmt.Errorf("inplace progress logging cannot be used when log-file is stderr") } - format := progressFormat + format := f.ProgressLogFormat if format == flags.ModeDefault { - format = resolveModeDefault(format) + format = f.resolveModeDefault(format) } progressLogger := cmdio.NewLogger(format) return cmdio.NewContext(ctx, progressLogger), nil } -var progressFormat = flags.NewProgressLogFormat() +func initProgressLoggerFlag(cmd *cobra.Command, logFlags *logFlags) *progressLoggerFlag { + f := progressLoggerFlag{ + ProgressLogFormat: flags.NewProgressLogFormat(), + + log: logFlags, + } -func init() { // Configure defaults from environment, if applicable. // If the provided value is invalid it is ignored. if v, ok := os.LookupEnv(envProgressFormat); ok { - progressFormat.Set(v) + f.Set(v) } - RootCmd.PersistentFlags().Var(&progressFormat, "progress-format", "format for progress logs (append, inplace, json)") - RootCmd.RegisterFlagCompletionFunc("progress-format", progressFormat.Complete) + + cmd.PersistentFlags().Var(&f.ProgressLogFormat, "progress-format", "format for progress logs (append, inplace, json)") + cmd.RegisterFlagCompletionFunc("progress-format", f.ProgressLogFormat.Complete) + return &f } diff --git a/cmd/root/progress_logger_test.go b/cmd/root/progress_logger_test.go index 30359257..9dceee8d 100644 --- a/cmd/root/progress_logger_test.go +++ b/cmd/root/progress_logger_test.go @@ -6,38 +6,62 @@ import ( "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/flags" + "github.com/spf13/cobra" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +type progressLoggerTest struct { + *cobra.Command + *logFlags + *progressLoggerFlag +} + +func initializeProgressLoggerTest(t *testing.T) ( + *progressLoggerTest, + *flags.LogLevelFlag, + *flags.LogFileFlag, + *flags.ProgressLogFormat, +) { + plt := &progressLoggerTest{ + Command: &cobra.Command{}, + } + plt.logFlags = initLogFlags(plt.Command) + plt.progressLoggerFlag = initProgressLoggerFlag(plt.Command, plt.logFlags) + return plt, &plt.logFlags.level, &plt.logFlags.file, &plt.progressLoggerFlag.ProgressLogFormat +} + func TestInitializeErrorOnIncompatibleConfig(t *testing.T) { + plt, logLevel, logFile, progressFormat := initializeProgressLoggerTest(t) logLevel.Set("info") logFile.Set("stderr") progressFormat.Set("inplace") - _, err := initializeProgressLogger(context.Background()) + _, err := plt.progressLoggerFlag.initializeContext(context.Background()) assert.ErrorContains(t, err, "inplace progress logging cannot be used when log-file is stderr") } func TestNoErrorOnDisabledLogLevel(t *testing.T) { + plt, logLevel, logFile, progressFormat := initializeProgressLoggerTest(t) logLevel.Set("disabled") logFile.Set("stderr") progressFormat.Set("inplace") - _, err := initializeProgressLogger(context.Background()) + _, err := plt.progressLoggerFlag.initializeContext(context.Background()) assert.NoError(t, err) } func TestNoErrorOnNonStderrLogFile(t *testing.T) { + plt, logLevel, logFile, progressFormat := initializeProgressLoggerTest(t) logLevel.Set("info") logFile.Set("stdout") progressFormat.Set("inplace") - _, err := initializeProgressLogger(context.Background()) + _, err := plt.progressLoggerFlag.initializeContext(context.Background()) assert.NoError(t, err) } func TestDefaultLoggerModeResolution(t *testing.T) { - progressFormat = flags.NewProgressLogFormat() - require.Equal(t, progressFormat, flags.ModeDefault) - ctx, err := initializeProgressLogger(context.Background()) + plt, _, _, progressFormat := initializeProgressLoggerTest(t) + require.Equal(t, *progressFormat, flags.ModeDefault) + ctx, err := plt.progressLoggerFlag.initializeContext(context.Background()) require.NoError(t, err) logger, ok := cmdio.FromContext(ctx) assert.True(t, ok) diff --git a/cmd/root/root.go b/cmd/root/root.go index 663dd645..45fc27f2 100644 --- a/cmd/root/root.go +++ b/cmd/root/root.go @@ -13,26 +13,34 @@ import ( "golang.org/x/exp/slog" ) -// RootCmd represents the base command when called without any subcommands -var RootCmd = &cobra.Command{ - Use: "databricks", - Short: "Databricks CLI", - Version: build.GetInfo().Version, +func New() *cobra.Command { + cmd := &cobra.Command{ + Use: "databricks", + Short: "Databricks CLI", + Version: build.GetInfo().Version, - // Cobra prints the usage string to stderr if a command returns an error. - // This usage string should only be displayed if an invalid combination of flags - // is specified and not when runtime errors occur (e.g. resource not found). - // The usage string is include in [flagErrorFunc] for flag errors only. - SilenceUsage: true, + // Cobra prints the usage string to stderr if a command returns an error. + // This usage string should only be displayed if an invalid combination of flags + // is specified and not when runtime errors occur (e.g. resource not found). + // The usage string is include in [flagErrorFunc] for flag errors only. + SilenceUsage: true, - // Silence error printing by cobra. Errors are printed through cmdio. - SilenceErrors: true, + // Silence error printing by cobra. Errors are printed through cmdio. + SilenceErrors: true, + } - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + // Initialize flags + logFlags := initLogFlags(cmd) + progressLoggerFlag := initProgressLoggerFlag(cmd, logFlags) + outputFlag := initOutputFlag(cmd) + initProfileFlag(cmd) + initEnvironmentFlag(cmd) + + cmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() // Configure default logger. - ctx, err := initializeLogger(ctx) + ctx, err := logFlags.initializeContext(ctx) if err != nil { return err } @@ -43,7 +51,7 @@ var RootCmd = &cobra.Command{ slog.String("args", strings.Join(os.Args, ", "))) // Configure progress logger - ctx, err = initializeProgressLogger(ctx) + ctx, err = progressLoggerFlag.initializeContext(ctx) if err != nil { return err } @@ -51,7 +59,7 @@ var RootCmd = &cobra.Command{ cmd.SetContext(ctx) // Configure command IO - err = initializeIO(cmd) + err = outputFlag.initializeIO(cmd) if err != nil { return err } @@ -63,7 +71,11 @@ var RootCmd = &cobra.Command{ ctx = withUpstreamInUserAgent(ctx) cmd.SetContext(ctx) return nil - }, + } + + cmd.SetFlagErrorFunc(flagErrorFunc) + cmd.SetVersionTemplate("Databricks CLI v{{.Version}}\n") + return cmd } // Wrap flag errors to include the usage string. @@ -104,7 +116,5 @@ func Execute(cmd *cobra.Command) { } } -func init() { - RootCmd.SetFlagErrorFunc(flagErrorFunc) - RootCmd.SetVersionTemplate("Databricks CLI v{{.Version}}\n") -} +// Keep a global copy until all commands can be initialized. +var RootCmd = New()