Make sure grouped flags are added to the command flag set (#1180)

## Changes
Make sure grouped flags are added to the command flag set

## Tests
Added regression tests
This commit is contained in:
Andrew Nester 2024-02-07 11:27:13 +01:00 committed by GitHub
parent 0b5fdcc346
commit de363faa53
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 48 additions and 11 deletions

View File

@ -12,15 +12,19 @@ type Options struct {
} }
func (o *Options) Define(cmd *cobra.Command) { func (o *Options) Define(cmd *cobra.Command) {
wrappedCmd := cmdgroup.NewCommandWithGroupFlag(cmd) jobGroup := cmdgroup.NewFlagGroup("Job")
jobGroup := wrappedCmd.AddFlagGroup("Job")
o.Job.DefineJobOptions(jobGroup.FlagSet()) o.Job.DefineJobOptions(jobGroup.FlagSet())
jobTaskGroup := wrappedCmd.AddFlagGroup("Job Task") jobTaskGroup := cmdgroup.NewFlagGroup("Job Task")
jobTaskGroup.SetDescription(`Note: please prefer use of job-level parameters (--param) over task-level parameters. jobTaskGroup.SetDescription(`Note: please prefer use of job-level parameters (--param) over task-level parameters.
For more information, see https://docs.databricks.com/en/workflows/jobs/create-run-jobs.html#pass-parameters-to-a-databricks-job-task`) For more information, see https://docs.databricks.com/en/workflows/jobs/create-run-jobs.html#pass-parameters-to-a-databricks-job-task`)
o.Job.DefineTaskOptions(jobTaskGroup.FlagSet()) o.Job.DefineTaskOptions(jobTaskGroup.FlagSet())
pipelineGroup := wrappedCmd.AddFlagGroup("Pipeline") pipelineGroup := cmdgroup.NewFlagGroup("Pipeline")
o.Pipeline.Define(pipelineGroup.FlagSet()) o.Pipeline.Define(pipelineGroup.FlagSet())
wrappedCmd := cmdgroup.NewCommandWithGroupFlag(cmd)
wrappedCmd.AddFlagGroup(jobGroup)
wrappedCmd.AddFlagGroup(jobTaskGroup)
wrappedCmd.AddFlagGroup(pipelineGroup)
} }

View File

@ -23,6 +23,24 @@ func (c *CommandWithGroupFlag) FlagGroups() []*FlagGroup {
return c.flagGroups return c.flagGroups
} }
func (c *CommandWithGroupFlag) NonGroupedFlags() *pflag.FlagSet {
nonGrouped := pflag.NewFlagSet("non-grouped", pflag.ContinueOnError)
c.cmd.LocalFlags().VisitAll(func(f *pflag.Flag) {
for _, fg := range c.flagGroups {
if fg.Has(f) {
return
}
}
nonGrouped.AddFlag(f)
})
return nonGrouped
}
func (c *CommandWithGroupFlag) HasNonGroupedFlags() bool {
return c.NonGroupedFlags().HasFlags()
}
func NewCommandWithGroupFlag(cmd *cobra.Command) *CommandWithGroupFlag { func NewCommandWithGroupFlag(cmd *cobra.Command) *CommandWithGroupFlag {
cmdWithFlagGroups := &CommandWithGroupFlag{cmd: cmd, flagGroups: make([]*FlagGroup, 0)} cmdWithFlagGroups := &CommandWithGroupFlag{cmd: cmd, flagGroups: make([]*FlagGroup, 0)}
cmd.SetUsageFunc(func(c *cobra.Command) error { cmd.SetUsageFunc(func(c *cobra.Command) error {
@ -36,10 +54,9 @@ func NewCommandWithGroupFlag(cmd *cobra.Command) *CommandWithGroupFlag {
return cmdWithFlagGroups return cmdWithFlagGroups
} }
func (c *CommandWithGroupFlag) AddFlagGroup(name string) *FlagGroup { func (c *CommandWithGroupFlag) AddFlagGroup(fg *FlagGroup) {
fg := &FlagGroup{name: name, flagSet: pflag.NewFlagSet(name, pflag.ContinueOnError)}
c.flagGroups = append(c.flagGroups, fg) c.flagGroups = append(c.flagGroups, fg)
return fg c.cmd.Flags().AddFlagSet(fg.FlagSet())
} }
type FlagGroup struct { type FlagGroup struct {
@ -48,6 +65,10 @@ type FlagGroup struct {
flagSet *pflag.FlagSet flagSet *pflag.FlagSet
} }
func NewFlagGroup(name string) *FlagGroup {
return &FlagGroup{name: name, flagSet: pflag.NewFlagSet(name, pflag.ContinueOnError)}
}
func (c *FlagGroup) Name() string { func (c *FlagGroup) Name() string {
return c.name return c.name
} }
@ -64,6 +85,10 @@ func (c *FlagGroup) FlagSet() *pflag.FlagSet {
return c.flagSet return c.flagSet
} }
func (c *FlagGroup) Has(f *pflag.Flag) bool {
return c.flagSet.Lookup(f.Name) != nil
}
var templateFuncs = template.FuncMap{ var templateFuncs = template.FuncMap{
"trim": strings.TrimSpace, "trim": strings.TrimSpace,
"trimRightSpace": trimRightSpace, "trimRightSpace": trimRightSpace,

View File

@ -18,15 +18,17 @@ func TestCommandFlagGrouping(t *testing.T) {
} }
wrappedCmd := NewCommandWithGroupFlag(cmd) wrappedCmd := NewCommandWithGroupFlag(cmd)
jobGroup := wrappedCmd.AddFlagGroup("Job") jobGroup := NewFlagGroup("Job")
fs := jobGroup.FlagSet() fs := jobGroup.FlagSet()
fs.String("job-name", "", "Name of the job") fs.String("job-name", "", "Name of the job")
fs.String("job-type", "", "Type of the job") fs.String("job-type", "", "Type of the job")
wrappedCmd.AddFlagGroup(jobGroup)
pipelineGroup := wrappedCmd.AddFlagGroup("Pipeline") pipelineGroup := NewFlagGroup("Pipeline")
fs = pipelineGroup.FlagSet() fs = pipelineGroup.FlagSet()
fs.String("pipeline-name", "", "Name of the pipeline") fs.String("pipeline-name", "", "Name of the pipeline")
fs.String("pipeline-type", "", "Type of the pipeline") fs.String("pipeline-type", "", "Type of the pipeline")
wrappedCmd.AddFlagGroup(pipelineGroup)
cmd.Flags().BoolP("bool", "b", false, "Bool flag") cmd.Flags().BoolP("bool", "b", false, "Bool flag")
@ -48,4 +50,10 @@ Pipeline Flags:
Flags: Flags:
-b, --bool Bool flag` -b, --bool Bool flag`
require.Equal(t, expected, buf.String()) require.Equal(t, expected, buf.String())
require.NotNil(t, cmd.Flags().Lookup("job-name"))
require.NotNil(t, cmd.Flags().Lookup("job-type"))
require.NotNil(t, cmd.Flags().Lookup("pipeline-name"))
require.NotNil(t, cmd.Flags().Lookup("pipeline-type"))
require.NotNil(t, cmd.Flags().Lookup("bool"))
} }

View File

@ -7,8 +7,8 @@ const usageTemplate = `Usage:{{if .Command.Runnable}}
{{.Description}}{{end}} {{.Description}}{{end}}
{{.FlagSet.FlagUsages | trimTrailingWhitespaces}} {{.FlagSet.FlagUsages | trimTrailingWhitespaces}}
{{end}} {{end}}
{{if .Command.HasAvailableLocalFlags}}Flags: {{if .HasNonGroupedFlags}}Flags:
{{.Command.LocalFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}{{if .Command.HasAvailableInheritedFlags}} {{.NonGroupedFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}{{if .Command.HasAvailableInheritedFlags}}
Global Flags: Global Flags:
{{.Command.InheritedFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}` {{.Command.InheritedFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}`