Remove dependency on global state for the root command (#606)

## Changes

This change is another step towards a CLI without globals. Also see #595.

The flags for the root command are now encapsulated in struct types.

## Tests

Unit tests pass.
This commit is contained in:
Pieter Noordhuis 2023-07-26 13:17:09 +02:00 committed by GitHub
parent cfff140815
commit ec892aa11c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 177 additions and 98 deletions

View File

@ -47,7 +47,7 @@ var runCmd = &cobra.Command{
return err
}
if output != nil {
switch root.OutputType() {
switch root.OutputType(cmd) {
case flags.OutputText:
resultString, err := output.String()
if err != nil {
@ -61,7 +61,7 @@ var runCmd = &cobra.Command{
}
cmd.OutOrStdout().Write(b)
default:
return fmt.Errorf("unknown output type %s", root.OutputType())
return fmt.Errorf("unknown output type %s", root.OutputType(cmd))
}
}
return nil

View File

@ -21,9 +21,9 @@ var workspaceClient int
var accountClient int
var currentUser int
func init() {
RootCmd.PersistentFlags().StringP("profile", "p", "", "~/.databrickscfg profile")
RootCmd.RegisterFlagCompletionFunc("profile", databrickscfg.ProfileCompletion)
func initProfileFlag(cmd *cobra.Command) {
cmd.PersistentFlags().StringP("profile", "p", "", "~/.databrickscfg profile")
cmd.RegisterFlagCompletionFunc("profile", databrickscfg.ProfileCompletion)
}
func MustAccountClient(cmd *cobra.Command, args []string) error {

View File

@ -118,8 +118,8 @@ func environmentCompletion(cmd *cobra.Command, args []string, toComplete string)
return maps.Keys(b.Config.Environments), cobra.ShellCompDirectiveDefault
}
func init() {
func initEnvironmentFlag(cmd *cobra.Command) {
// To operate in the context of a bundle, all commands must take an "environment" parameter.
RootCmd.PersistentFlags().StringP("environment", "e", "", "bundle environment to use (if applicable)")
RootCmd.RegisterFlagCompletionFunc("environment", environmentCompletion)
cmd.PersistentFlags().StringP("environment", "e", "", "bundle environment to use (if applicable)")
cmd.RegisterFlagCompletionFunc("environment", environmentCompletion)
}

View File

@ -9,6 +9,7 @@ import (
"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
)
@ -27,15 +28,18 @@ func setupDatabricksCfg(t *testing.T) {
t.Setenv(homeEnvVar, tempHomeDir)
}
func setup(t *testing.T, host string) *bundle.Bundle {
func emptyCommand(t *testing.T) *cobra.Command {
ctx := context.Background()
cmd := &cobra.Command{}
cmd.SetContext(ctx)
initProfileFlag(cmd)
return cmd
}
func setup(t *testing.T, cmd *cobra.Command, host string) *bundle.Bundle {
setupDatabricksCfg(t)
ctx := context.Background()
RootCmd.SetContext(ctx)
_, err := initializeLogger(ctx)
assert.NoError(t, err)
err = configureBundle(RootCmd, []string{"validate"}, func() (*bundle.Bundle, error) {
err := configureBundle(cmd, []string{"validate"}, func() (*bundle.Bundle, error) {
return &bundle.Bundle{
Config: config.Root{
Bundle: config.Bundle{
@ -48,46 +52,50 @@ func setup(t *testing.T, host string) *bundle.Bundle {
}, nil
})
assert.NoError(t, err)
return bundle.Get(RootCmd.Context())
return bundle.Get(cmd.Context())
}
func TestBundleConfigureDefault(t *testing.T) {
b := setup(t, "https://x.com")
cmd := emptyCommand(t)
b := setup(t, cmd, "https://x.com")
assert.NotPanics(t, func() {
b.WorkspaceClient()
})
}
func TestBundleConfigureWithMultipleMatches(t *testing.T) {
b := setup(t, "https://a.com")
cmd := emptyCommand(t)
b := setup(t, cmd, "https://a.com")
assert.Panics(t, func() {
b.WorkspaceClient()
})
}
func TestBundleConfigureWithNonExistentProfileFlag(t *testing.T) {
RootCmd.Flag("profile").Value.Set("NOEXIST")
cmd := emptyCommand(t)
cmd.Flag("profile").Value.Set("NOEXIST")
b := setup(t, "https://x.com")
b := setup(t, cmd, "https://x.com")
assert.PanicsWithError(t, "no matching config profiles found", func() {
b.WorkspaceClient()
})
}
func TestBundleConfigureWithMismatchedProfile(t *testing.T) {
RootCmd.Flag("profile").Value.Set("PROFILE-1")
cmd := emptyCommand(t)
cmd.Flag("profile").Value.Set("PROFILE-1")
b := setup(t, "https://x.com")
b := setup(t, cmd, "https://x.com")
assert.PanicsWithError(t, "config host mismatch: profile uses host https://a.com, but CLI configured to use https://x.com", func() {
b.WorkspaceClient()
})
}
func TestBundleConfigureWithCorrectProfile(t *testing.T) {
RootCmd.Flag("profile").Value.Set("PROFILE-1")
cmd := emptyCommand(t)
cmd.Flag("profile").Value.Set("PROFILE-1")
b := setup(t, "https://a.com")
b := setup(t, cmd, "https://a.com")
assert.NotPanics(t, func() {
b.WorkspaceClient()
})
@ -99,7 +107,8 @@ func TestBundleConfigureWithMismatchedProfileEnvVariable(t *testing.T) {
t.Setenv("DATABRICKS_CONFIG_PROFILE", "")
})
b := setup(t, "https://x.com")
cmd := emptyCommand(t)
b := setup(t, cmd, "https://x.com")
assert.PanicsWithError(t, "config host mismatch: profile uses host https://a.com, but CLI configured to use https://x.com", func() {
b.WorkspaceClient()
})
@ -110,9 +119,11 @@ func TestBundleConfigureWithProfileFlagAndEnvVariable(t *testing.T) {
t.Cleanup(func() {
t.Setenv("DATABRICKS_CONFIG_PROFILE", "")
})
RootCmd.Flag("profile").Value.Set("PROFILE-1")
b := setup(t, "https://a.com")
cmd := emptyCommand(t)
cmd.Flag("profile").Value.Set("PROFILE-1")
b := setup(t, cmd, "https://a.com")
assert.NotPanics(t, func() {
b.WorkspaceClient()
})

View File

@ -10,32 +10,43 @@ import (
const envOutputFormat = "DATABRICKS_OUTPUT_FORMAT"
var outputType flags.Output = flags.OutputText
type outputFlag struct {
output flags.Output
}
func initOutputFlag(cmd *cobra.Command) *outputFlag {
f := outputFlag{
output: flags.OutputText,
}
func init() {
// Configure defaults from environment, if applicable.
// If the provided value is invalid it is ignored.
if v, ok := os.LookupEnv(envOutputFormat); ok {
outputType.Set(v)
f.output.Set(v)
}
RootCmd.PersistentFlags().VarP(&outputType, "output", "o", "output type: text or json")
cmd.PersistentFlags().VarP(&f.output, "output", "o", "output type: text or json")
return &f
}
func OutputType() flags.Output {
return outputType
func OutputType(cmd *cobra.Command) flags.Output {
f, ok := cmd.Flag("output").Value.(*flags.Output)
if !ok {
panic("output flag not defined")
}
func initializeIO(cmd *cobra.Command) error {
return *f
}
func (f *outputFlag) initializeIO(cmd *cobra.Command) error {
var template string
if cmd.Annotations != nil {
// rely on zeroval being an empty string
template = cmd.Annotations["template"]
}
cmdIO := cmdio.NewIO(outputType, cmd.InOrStdin(), cmd.OutOrStdout(), cmd.ErrOrStderr(), template)
cmdIO := cmdio.NewIO(f.output, cmd.InOrStdin(), cmd.OutOrStdout(), cmd.ErrOrStderr(), template)
ctx := cmdio.InContext(cmd.Context(), cmdIO)
cmd.SetContext(ctx)
return nil
}

View File

@ -10,6 +10,7 @@ import (
"github.com/databricks/cli/libs/flags"
"github.com/databricks/cli/libs/log"
"github.com/fatih/color"
"github.com/spf13/cobra"
"golang.org/x/exp/slog"
)
@ -66,12 +67,18 @@ func (l *friendlyHandler) Handle(ctx context.Context, rec slog.Record) error {
return err
}
func makeLogHandler(opts slog.HandlerOptions) (slog.Handler, error) {
switch logOutput {
type logFlags struct {
file flags.LogFileFlag
level flags.LogLevelFlag
output flags.Output
}
func (f *logFlags) makeLogHandler(opts slog.HandlerOptions) (slog.Handler, error) {
switch f.output {
case flags.OutputJSON:
return opts.NewJSONHandler(logFile.Writer()), nil
return opts.NewJSONHandler(f.file.Writer()), nil
case flags.OutputText:
w := logFile.Writer()
w := f.file.Writer()
if cmdio.IsTTY(w) {
return &friendlyHandler{
Handler: opts.NewTextHandler(w),
@ -81,13 +88,13 @@ func makeLogHandler(opts slog.HandlerOptions) (slog.Handler, error) {
return opts.NewTextHandler(w), nil
default:
return nil, fmt.Errorf("invalid log output mode: %s", logOutput)
return nil, fmt.Errorf("invalid log output mode: %s", f.output)
}
}
func initializeLogger(ctx context.Context) (context.Context, error) {
func (f *logFlags) initializeContext(ctx context.Context) (context.Context, error) {
opts := slog.HandlerOptions{}
opts.Level = logLevel.Level()
opts.Level = f.level.Level()
opts.AddSource = true
opts.ReplaceAttr = log.ReplaceAttrFunctions{
log.ReplaceLevelAttr,
@ -95,12 +102,12 @@ func initializeLogger(ctx context.Context) (context.Context, error) {
}.ReplaceAttr
// Open the underlying log file if the user configured an actual file to log to.
err := logFile.Open()
err := f.file.Open()
if err != nil {
return nil, err
}
handler, err := makeLogHandler(opts)
handler, err := f.makeLogHandler(opts)
if err != nil {
return nil, err
}
@ -109,27 +116,30 @@ func initializeLogger(ctx context.Context) (context.Context, error) {
return log.NewContext(ctx, slog.Default()), nil
}
var logFile = flags.NewLogFileFlag()
var logLevel = flags.NewLogLevelFlag()
var logOutput = flags.OutputText
func initLogFlags(cmd *cobra.Command) *logFlags {
f := logFlags{
file: flags.NewLogFileFlag(),
level: flags.NewLogLevelFlag(),
output: flags.OutputText,
}
func init() {
// Configure defaults from environment, if applicable.
// If the provided value is invalid it is ignored.
if v, ok := os.LookupEnv(envLogFile); ok {
logFile.Set(v)
f.file.Set(v)
}
if v, ok := os.LookupEnv(envLogLevel); ok {
logLevel.Set(v)
f.level.Set(v)
}
if v, ok := os.LookupEnv(envLogFormat); ok {
logOutput.Set(v)
f.output.Set(v)
}
RootCmd.PersistentFlags().Var(&logFile, "log-file", "file to write logs to")
RootCmd.PersistentFlags().Var(&logLevel, "log-level", "log level")
RootCmd.PersistentFlags().Var(&logOutput, "log-format", "log output format (text or json)")
RootCmd.RegisterFlagCompletionFunc("log-file", logFile.Complete)
RootCmd.RegisterFlagCompletionFunc("log-level", logLevel.Complete)
RootCmd.RegisterFlagCompletionFunc("log-format", logOutput.Complete)
cmd.PersistentFlags().Var(&f.file, "log-file", "file to write logs to")
cmd.PersistentFlags().Var(&f.level, "log-level", "log level")
cmd.PersistentFlags().Var(&f.output, "log-format", "log output format (text or json)")
cmd.RegisterFlagCompletionFunc("log-file", f.file.Complete)
cmd.RegisterFlagCompletionFunc("log-level", f.level.Complete)
cmd.RegisterFlagCompletionFunc("log-format", f.output.Complete)
return &f
}

View File

@ -7,42 +7,55 @@ import (
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/flags"
"github.com/spf13/cobra"
"golang.org/x/term"
)
const envProgressFormat = "DATABRICKS_CLI_PROGRESS_FORMAT"
func resolveModeDefault(format flags.ProgressLogFormat) flags.ProgressLogFormat {
if (logLevel.String() == "disabled" || logFile.String() != "stderr") &&
type progressLoggerFlag struct {
flags.ProgressLogFormat
log *logFlags
}
func (f *progressLoggerFlag) resolveModeDefault(format flags.ProgressLogFormat) flags.ProgressLogFormat {
if (f.log.level.String() == "disabled" || f.log.file.String() != "stderr") &&
term.IsTerminal(int(os.Stderr.Fd())) {
return flags.ModeInplace
}
return flags.ModeAppend
}
func initializeProgressLogger(ctx context.Context) (context.Context, error) {
if logLevel.String() != "disabled" && logFile.String() == "stderr" &&
progressFormat == flags.ModeInplace {
func (f *progressLoggerFlag) initializeContext(ctx context.Context) (context.Context, error) {
if f.log.level.String() != "disabled" && f.log.file.String() == "stderr" &&
f.ProgressLogFormat == flags.ModeInplace {
return nil, fmt.Errorf("inplace progress logging cannot be used when log-file is stderr")
}
format := progressFormat
format := f.ProgressLogFormat
if format == flags.ModeDefault {
format = resolveModeDefault(format)
format = f.resolveModeDefault(format)
}
progressLogger := cmdio.NewLogger(format)
return cmdio.NewContext(ctx, progressLogger), nil
}
var progressFormat = flags.NewProgressLogFormat()
func initProgressLoggerFlag(cmd *cobra.Command, logFlags *logFlags) *progressLoggerFlag {
f := progressLoggerFlag{
ProgressLogFormat: flags.NewProgressLogFormat(),
log: logFlags,
}
func init() {
// Configure defaults from environment, if applicable.
// If the provided value is invalid it is ignored.
if v, ok := os.LookupEnv(envProgressFormat); ok {
progressFormat.Set(v)
f.Set(v)
}
RootCmd.PersistentFlags().Var(&progressFormat, "progress-format", "format for progress logs (append, inplace, json)")
RootCmd.RegisterFlagCompletionFunc("progress-format", progressFormat.Complete)
cmd.PersistentFlags().Var(&f.ProgressLogFormat, "progress-format", "format for progress logs (append, inplace, json)")
cmd.RegisterFlagCompletionFunc("progress-format", f.ProgressLogFormat.Complete)
return &f
}

View File

@ -6,38 +6,62 @@ import (
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/flags"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type progressLoggerTest struct {
*cobra.Command
*logFlags
*progressLoggerFlag
}
func initializeProgressLoggerTest(t *testing.T) (
*progressLoggerTest,
*flags.LogLevelFlag,
*flags.LogFileFlag,
*flags.ProgressLogFormat,
) {
plt := &progressLoggerTest{
Command: &cobra.Command{},
}
plt.logFlags = initLogFlags(plt.Command)
plt.progressLoggerFlag = initProgressLoggerFlag(plt.Command, plt.logFlags)
return plt, &plt.logFlags.level, &plt.logFlags.file, &plt.progressLoggerFlag.ProgressLogFormat
}
func TestInitializeErrorOnIncompatibleConfig(t *testing.T) {
plt, logLevel, logFile, progressFormat := initializeProgressLoggerTest(t)
logLevel.Set("info")
logFile.Set("stderr")
progressFormat.Set("inplace")
_, err := initializeProgressLogger(context.Background())
_, err := plt.progressLoggerFlag.initializeContext(context.Background())
assert.ErrorContains(t, err, "inplace progress logging cannot be used when log-file is stderr")
}
func TestNoErrorOnDisabledLogLevel(t *testing.T) {
plt, logLevel, logFile, progressFormat := initializeProgressLoggerTest(t)
logLevel.Set("disabled")
logFile.Set("stderr")
progressFormat.Set("inplace")
_, err := initializeProgressLogger(context.Background())
_, err := plt.progressLoggerFlag.initializeContext(context.Background())
assert.NoError(t, err)
}
func TestNoErrorOnNonStderrLogFile(t *testing.T) {
plt, logLevel, logFile, progressFormat := initializeProgressLoggerTest(t)
logLevel.Set("info")
logFile.Set("stdout")
progressFormat.Set("inplace")
_, err := initializeProgressLogger(context.Background())
_, err := plt.progressLoggerFlag.initializeContext(context.Background())
assert.NoError(t, err)
}
func TestDefaultLoggerModeResolution(t *testing.T) {
progressFormat = flags.NewProgressLogFormat()
require.Equal(t, progressFormat, flags.ModeDefault)
ctx, err := initializeProgressLogger(context.Background())
plt, _, _, progressFormat := initializeProgressLoggerTest(t)
require.Equal(t, *progressFormat, flags.ModeDefault)
ctx, err := plt.progressLoggerFlag.initializeContext(context.Background())
require.NoError(t, err)
logger, ok := cmdio.FromContext(ctx)
assert.True(t, ok)

View File

@ -13,8 +13,8 @@ import (
"golang.org/x/exp/slog"
)
// RootCmd represents the base command when called without any subcommands
var RootCmd = &cobra.Command{
func New() *cobra.Command {
cmd := &cobra.Command{
Use: "databricks",
Short: "Databricks CLI",
Version: build.GetInfo().Version,
@ -27,12 +27,20 @@ var RootCmd = &cobra.Command{
// Silence error printing by cobra. Errors are printed through cmdio.
SilenceErrors: true,
}
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
// Initialize flags
logFlags := initLogFlags(cmd)
progressLoggerFlag := initProgressLoggerFlag(cmd, logFlags)
outputFlag := initOutputFlag(cmd)
initProfileFlag(cmd)
initEnvironmentFlag(cmd)
cmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
// Configure default logger.
ctx, err := initializeLogger(ctx)
ctx, err := logFlags.initializeContext(ctx)
if err != nil {
return err
}
@ -43,7 +51,7 @@ var RootCmd = &cobra.Command{
slog.String("args", strings.Join(os.Args, ", ")))
// Configure progress logger
ctx, err = initializeProgressLogger(ctx)
ctx, err = progressLoggerFlag.initializeContext(ctx)
if err != nil {
return err
}
@ -51,7 +59,7 @@ var RootCmd = &cobra.Command{
cmd.SetContext(ctx)
// Configure command IO
err = initializeIO(cmd)
err = outputFlag.initializeIO(cmd)
if err != nil {
return err
}
@ -63,7 +71,11 @@ var RootCmd = &cobra.Command{
ctx = withUpstreamInUserAgent(ctx)
cmd.SetContext(ctx)
return nil
},
}
cmd.SetFlagErrorFunc(flagErrorFunc)
cmd.SetVersionTemplate("Databricks CLI v{{.Version}}\n")
return cmd
}
// Wrap flag errors to include the usage string.
@ -104,7 +116,5 @@ func Execute(cmd *cobra.Command) {
}
}
func init() {
RootCmd.SetFlagErrorFunc(flagErrorFunc)
RootCmd.SetVersionTemplate("Databricks CLI v{{.Version}}\n")
}
// Keep a global copy until all commands can be initialized.
var RootCmd = New()