From de363faa5347c2ce98c0377d8be931ff96fa8908 Mon Sep 17 00:00:00 2001 From: Andrew Nester Date: Wed, 7 Feb 2024 11:27:13 +0100 Subject: [PATCH] Make sure grouped flags are added to the command flag set (#1180) ## Changes Make sure grouped flags are added to the command flag set ## Tests Added regression tests --- bundle/run/options.go | 12 ++++++++---- libs/cmdgroup/command.go | 31 ++++++++++++++++++++++++++++--- libs/cmdgroup/command_test.go | 12 ++++++++++-- libs/cmdgroup/template.go | 4 ++-- 4 files changed, 48 insertions(+), 11 deletions(-) diff --git a/bundle/run/options.go b/bundle/run/options.go index 580612d0..4e50788a 100644 --- a/bundle/run/options.go +++ b/bundle/run/options.go @@ -12,15 +12,19 @@ type Options struct { } func (o *Options) Define(cmd *cobra.Command) { - wrappedCmd := cmdgroup.NewCommandWithGroupFlag(cmd) - jobGroup := wrappedCmd.AddFlagGroup("Job") + jobGroup := cmdgroup.NewFlagGroup("Job") o.Job.DefineJobOptions(jobGroup.FlagSet()) - jobTaskGroup := wrappedCmd.AddFlagGroup("Job Task") + jobTaskGroup := cmdgroup.NewFlagGroup("Job Task") jobTaskGroup.SetDescription(`Note: please prefer use of job-level parameters (--param) over task-level parameters. For more information, see https://docs.databricks.com/en/workflows/jobs/create-run-jobs.html#pass-parameters-to-a-databricks-job-task`) o.Job.DefineTaskOptions(jobTaskGroup.FlagSet()) - pipelineGroup := wrappedCmd.AddFlagGroup("Pipeline") + pipelineGroup := cmdgroup.NewFlagGroup("Pipeline") o.Pipeline.Define(pipelineGroup.FlagSet()) + + wrappedCmd := cmdgroup.NewCommandWithGroupFlag(cmd) + wrappedCmd.AddFlagGroup(jobGroup) + wrappedCmd.AddFlagGroup(jobTaskGroup) + wrappedCmd.AddFlagGroup(pipelineGroup) } diff --git a/libs/cmdgroup/command.go b/libs/cmdgroup/command.go index 19c9af16..a2a77693 100644 --- a/libs/cmdgroup/command.go +++ b/libs/cmdgroup/command.go @@ -23,6 +23,24 @@ func (c *CommandWithGroupFlag) FlagGroups() []*FlagGroup { return c.flagGroups } +func (c *CommandWithGroupFlag) NonGroupedFlags() *pflag.FlagSet { + nonGrouped := pflag.NewFlagSet("non-grouped", pflag.ContinueOnError) + c.cmd.LocalFlags().VisitAll(func(f *pflag.Flag) { + for _, fg := range c.flagGroups { + if fg.Has(f) { + return + } + } + nonGrouped.AddFlag(f) + }) + + return nonGrouped +} + +func (c *CommandWithGroupFlag) HasNonGroupedFlags() bool { + return c.NonGroupedFlags().HasFlags() +} + func NewCommandWithGroupFlag(cmd *cobra.Command) *CommandWithGroupFlag { cmdWithFlagGroups := &CommandWithGroupFlag{cmd: cmd, flagGroups: make([]*FlagGroup, 0)} cmd.SetUsageFunc(func(c *cobra.Command) error { @@ -36,10 +54,9 @@ func NewCommandWithGroupFlag(cmd *cobra.Command) *CommandWithGroupFlag { return cmdWithFlagGroups } -func (c *CommandWithGroupFlag) AddFlagGroup(name string) *FlagGroup { - fg := &FlagGroup{name: name, flagSet: pflag.NewFlagSet(name, pflag.ContinueOnError)} +func (c *CommandWithGroupFlag) AddFlagGroup(fg *FlagGroup) { c.flagGroups = append(c.flagGroups, fg) - return fg + c.cmd.Flags().AddFlagSet(fg.FlagSet()) } type FlagGroup struct { @@ -48,6 +65,10 @@ type FlagGroup struct { flagSet *pflag.FlagSet } +func NewFlagGroup(name string) *FlagGroup { + return &FlagGroup{name: name, flagSet: pflag.NewFlagSet(name, pflag.ContinueOnError)} +} + func (c *FlagGroup) Name() string { return c.name } @@ -64,6 +85,10 @@ func (c *FlagGroup) FlagSet() *pflag.FlagSet { return c.flagSet } +func (c *FlagGroup) Has(f *pflag.Flag) bool { + return c.flagSet.Lookup(f.Name) != nil +} + var templateFuncs = template.FuncMap{ "trim": strings.TrimSpace, "trimRightSpace": trimRightSpace, diff --git a/libs/cmdgroup/command_test.go b/libs/cmdgroup/command_test.go index 2eae31d1..9122c780 100644 --- a/libs/cmdgroup/command_test.go +++ b/libs/cmdgroup/command_test.go @@ -18,15 +18,17 @@ func TestCommandFlagGrouping(t *testing.T) { } wrappedCmd := NewCommandWithGroupFlag(cmd) - jobGroup := wrappedCmd.AddFlagGroup("Job") + jobGroup := NewFlagGroup("Job") fs := jobGroup.FlagSet() fs.String("job-name", "", "Name of the job") fs.String("job-type", "", "Type of the job") + wrappedCmd.AddFlagGroup(jobGroup) - pipelineGroup := wrappedCmd.AddFlagGroup("Pipeline") + pipelineGroup := NewFlagGroup("Pipeline") fs = pipelineGroup.FlagSet() fs.String("pipeline-name", "", "Name of the pipeline") fs.String("pipeline-type", "", "Type of the pipeline") + wrappedCmd.AddFlagGroup(pipelineGroup) cmd.Flags().BoolP("bool", "b", false, "Bool flag") @@ -48,4 +50,10 @@ Pipeline Flags: Flags: -b, --bool Bool flag` require.Equal(t, expected, buf.String()) + + require.NotNil(t, cmd.Flags().Lookup("job-name")) + require.NotNil(t, cmd.Flags().Lookup("job-type")) + require.NotNil(t, cmd.Flags().Lookup("pipeline-name")) + require.NotNil(t, cmd.Flags().Lookup("pipeline-type")) + require.NotNil(t, cmd.Flags().Lookup("bool")) } diff --git a/libs/cmdgroup/template.go b/libs/cmdgroup/template.go index aac967b0..5c1be48f 100644 --- a/libs/cmdgroup/template.go +++ b/libs/cmdgroup/template.go @@ -7,8 +7,8 @@ const usageTemplate = `Usage:{{if .Command.Runnable}} {{.Description}}{{end}} {{.FlagSet.FlagUsages | trimTrailingWhitespaces}} {{end}} -{{if .Command.HasAvailableLocalFlags}}Flags: -{{.Command.LocalFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}{{if .Command.HasAvailableInheritedFlags}} +{{if .HasNonGroupedFlags}}Flags: +{{.NonGroupedFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}{{if .Command.HasAvailableInheritedFlags}} Global Flags: {{.Command.InheritedFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}`