diff --git a/bundle/run/job_options.go b/bundle/run/job_options.go index 209591d7..c359e79e 100644 --- a/bundle/run/job_options.go +++ b/bundle/run/job_options.go @@ -27,8 +27,11 @@ type JobOptions struct { jobParams map[string]string } -func (o *JobOptions) Define(fs *flag.FlagSet) { - // Define task parameters flags. +func (o *JobOptions) DefineJobOptions(fs *flag.FlagSet) { + fs.StringToStringVar(&o.jobParams, "params", nil, "comma separated k=v pairs for job parameters") +} + +func (o *JobOptions) DefineTaskOptions(fs *flag.FlagSet) { fs.StringSliceVar(&o.dbtCommands, "dbt-commands", nil, "A list of commands to execute for jobs with DBT tasks.") fs.StringSliceVar(&o.jarParams, "jar-params", nil, "A list of parameters for jobs with Spark JAR tasks.") fs.StringToStringVar(&o.notebookParams, "notebook-params", nil, "A map from keys to values for jobs with notebook tasks.") @@ -37,9 +40,6 @@ func (o *JobOptions) Define(fs *flag.FlagSet) { fs.StringSliceVar(&o.pythonParams, "python-params", nil, "A list of parameters for jobs with Python tasks.") fs.StringSliceVar(&o.sparkSubmitParams, "spark-submit-params", nil, "A list of parameters for jobs with Spark submit tasks.") fs.StringToStringVar(&o.sqlParams, "sql-params", nil, "A map from keys to values for jobs with SQL tasks.") - - // Define job parameters flag. - fs.StringToStringVar(&o.jobParams, "params", nil, "comma separated k=v pairs for job parameters") } func (o *JobOptions) hasTaskParametersConfigured() bool { diff --git a/bundle/run/job_options_test.go b/bundle/run/job_options_test.go index 822771d8..08e18d95 100644 --- a/bundle/run/job_options_test.go +++ b/bundle/run/job_options_test.go @@ -13,7 +13,8 @@ import ( func setupJobOptions(t *testing.T) (*flag.FlagSet, *JobOptions) { var fs flag.FlagSet var opts JobOptions - opts.Define(&fs) + opts.DefineJobOptions(&fs) + opts.DefineTaskOptions(&fs) return &fs, &opts } diff --git a/bundle/run/options.go b/bundle/run/options.go index 3194fb32..580612d0 100644 --- a/bundle/run/options.go +++ b/bundle/run/options.go @@ -1,7 +1,8 @@ package run import ( - flag "github.com/spf13/pflag" + "github.com/databricks/cli/libs/cmdgroup" + "github.com/spf13/cobra" ) type Options struct { @@ -10,7 +11,16 @@ type Options struct { NoWait bool } -func (o *Options) Define(fs *flag.FlagSet) { - o.Job.Define(fs) - o.Pipeline.Define(fs) +func (o *Options) Define(cmd *cobra.Command) { + wrappedCmd := cmdgroup.NewCommandWithGroupFlag(cmd) + jobGroup := wrappedCmd.AddFlagGroup("Job") + o.Job.DefineJobOptions(jobGroup.FlagSet()) + + jobTaskGroup := wrappedCmd.AddFlagGroup("Job Task") + jobTaskGroup.SetDescription(`Note: please prefer use of job-level parameters (--param) over task-level parameters. + For more information, see https://docs.databricks.com/en/workflows/jobs/create-run-jobs.html#pass-parameters-to-a-databricks-job-task`) + o.Job.DefineTaskOptions(jobTaskGroup.FlagSet()) + + pipelineGroup := wrappedCmd.AddFlagGroup("Pipeline") + o.Pipeline.Define(pipelineGroup.FlagSet()) } diff --git a/cmd/bundle/run.go b/cmd/bundle/run.go index c9e35aa3..a4b10658 100644 --- a/cmd/bundle/run.go +++ b/cmd/bundle/run.go @@ -24,7 +24,7 @@ func newRunCommand() *cobra.Command { } var runOptions run.Options - runOptions.Define(cmd.Flags()) + runOptions.Define(cmd) var noWait bool cmd.Flags().BoolVar(&noWait, "no-wait", false, "Don't wait for the run to complete.") diff --git a/libs/cmdgroup/command.go b/libs/cmdgroup/command.go new file mode 100644 index 00000000..19c9af16 --- /dev/null +++ b/libs/cmdgroup/command.go @@ -0,0 +1,83 @@ +package cmdgroup + +import ( + "io" + "strings" + "text/template" + "unicode" + + "github.com/spf13/cobra" + "github.com/spf13/pflag" +) + +type CommandWithGroupFlag struct { + cmd *cobra.Command + flagGroups []*FlagGroup +} + +func (c *CommandWithGroupFlag) Command() *cobra.Command { + return c.cmd +} + +func (c *CommandWithGroupFlag) FlagGroups() []*FlagGroup { + return c.flagGroups +} + +func NewCommandWithGroupFlag(cmd *cobra.Command) *CommandWithGroupFlag { + cmdWithFlagGroups := &CommandWithGroupFlag{cmd: cmd, flagGroups: make([]*FlagGroup, 0)} + cmd.SetUsageFunc(func(c *cobra.Command) error { + err := tmpl(c.OutOrStderr(), c.UsageTemplate(), cmdWithFlagGroups) + if err != nil { + c.PrintErrln(err) + } + return nil + }) + cmd.SetUsageTemplate(usageTemplate) + return cmdWithFlagGroups +} + +func (c *CommandWithGroupFlag) AddFlagGroup(name string) *FlagGroup { + fg := &FlagGroup{name: name, flagSet: pflag.NewFlagSet(name, pflag.ContinueOnError)} + c.flagGroups = append(c.flagGroups, fg) + return fg +} + +type FlagGroup struct { + name string + description string + flagSet *pflag.FlagSet +} + +func (c *FlagGroup) Name() string { + return c.name +} + +func (c *FlagGroup) Description() string { + return c.description +} + +func (c *FlagGroup) SetDescription(description string) { + c.description = description +} + +func (c *FlagGroup) FlagSet() *pflag.FlagSet { + return c.flagSet +} + +var templateFuncs = template.FuncMap{ + "trim": strings.TrimSpace, + "trimRightSpace": trimRightSpace, + "trimTrailingWhitespaces": trimRightSpace, +} + +func trimRightSpace(s string) string { + return strings.TrimRightFunc(s, unicode.IsSpace) +} + +// tmpl executes the given template text on data, writing the result to w. +func tmpl(w io.Writer, text string, data interface{}) error { + t := template.New("top") + t.Funcs(templateFuncs) + template.Must(t.Parse(text)) + return t.Execute(w, data) +} diff --git a/libs/cmdgroup/command_test.go b/libs/cmdgroup/command_test.go new file mode 100644 index 00000000..2eae31d1 --- /dev/null +++ b/libs/cmdgroup/command_test.go @@ -0,0 +1,51 @@ +package cmdgroup + +import ( + "bytes" + "testing" + + "github.com/spf13/cobra" + "github.com/stretchr/testify/require" +) + +func TestCommandFlagGrouping(t *testing.T) { + cmd := &cobra.Command{ + Use: "test [flags]", + Short: "test command", + Run: func(cmd *cobra.Command, args []string) { + // Do nothing + }, + } + + wrappedCmd := NewCommandWithGroupFlag(cmd) + jobGroup := wrappedCmd.AddFlagGroup("Job") + fs := jobGroup.FlagSet() + fs.String("job-name", "", "Name of the job") + fs.String("job-type", "", "Type of the job") + + pipelineGroup := wrappedCmd.AddFlagGroup("Pipeline") + fs = pipelineGroup.FlagSet() + fs.String("pipeline-name", "", "Name of the pipeline") + fs.String("pipeline-type", "", "Type of the pipeline") + + cmd.Flags().BoolP("bool", "b", false, "Bool flag") + + buf := bytes.NewBuffer(nil) + cmd.SetOutput(buf) + cmd.Usage() + + expected := `Usage: + test [flags] + +Job Flags: + --job-name string Name of the job + --job-type string Type of the job + +Pipeline Flags: + --pipeline-name string Name of the pipeline + --pipeline-type string Type of the pipeline + +Flags: + -b, --bool Bool flag` + require.Equal(t, expected, buf.String()) +} diff --git a/libs/cmdgroup/template.go b/libs/cmdgroup/template.go new file mode 100644 index 00000000..aac967b0 --- /dev/null +++ b/libs/cmdgroup/template.go @@ -0,0 +1,14 @@ +package cmdgroup + +const usageTemplate = `Usage:{{if .Command.Runnable}} + {{.Command.UseLine}}{{end}} +{{range .FlagGroups}} +{{.Name}} Flags:{{if not (eq .Description "")}} + {{.Description}}{{end}} +{{.FlagSet.FlagUsages | trimTrailingWhitespaces}} +{{end}} +{{if .Command.HasAvailableLocalFlags}}Flags: +{{.Command.LocalFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}{{if .Command.HasAvailableInheritedFlags}} + +Global Flags: +{{.Command.InheritedFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}`