From 8f4461904b2b4b9e5971742d9a7b7c35a47f5aca Mon Sep 17 00:00:00 2001 From: Pieter Noordhuis Date: Fri, 23 Dec 2022 15:17:16 +0100 Subject: [PATCH] Define flags for running jobs and pipelines (#146) --- bundle/run/job.go | 80 +++++++++++++++++++++++++++++++++++++++--- bundle/run/options.go | 13 +++++++ bundle/run/pipeline.go | 71 ++++++++++++++++++++++++++++++++++--- bundle/run/runner.go | 2 +- cmd/bundle/run.go | 5 ++- go.mod | 2 +- 6 files changed, 162 insertions(+), 11 deletions(-) create mode 100644 bundle/run/options.go diff --git a/bundle/run/job.go b/bundle/run/job.go index 45925a75..f67cdcca 100644 --- a/bundle/run/job.go +++ b/bundle/run/job.go @@ -11,8 +11,77 @@ import ( "github.com/databricks/bricks/bundle/config/resources" "github.com/databricks/databricks-sdk-go/retries" "github.com/databricks/databricks-sdk-go/service/jobs" + flag "github.com/spf13/pflag" ) +// JobOptions defines options for running a job. +type JobOptions struct { + dbtCommands []string + jarParams []string + notebookParams map[string]string + pipelineParams map[string]string + pythonNamedParams map[string]string + pythonParams []string + sparkSubmitParams []string + sqlParams map[string]string +} + +func (o *JobOptions) Define(fs *flag.FlagSet) { + fs.StringSliceVar(&o.dbtCommands, "dbt-commands", nil, "A list of commands to execute for jobs with DBT tasks.") + fs.StringSliceVar(&o.jarParams, "jar-params", nil, "A list of parameters for jobs with Spark JAR tasks.") + fs.StringToStringVar(&o.notebookParams, "notebook-params", nil, "A map from keys to values for jobs with notebook tasks.") + fs.StringToStringVar(&o.pipelineParams, "pipeline-params", nil, "A map from keys to values for jobs with pipeline tasks.") + fs.StringToStringVar(&o.pythonNamedParams, "python-named-params", nil, "A map from keys to values for jobs with Python wheel tasks.") + fs.StringSliceVar(&o.pythonParams, "python-params", nil, "A list of parameters for jobs with Python tasks.") + fs.StringSliceVar(&o.sparkSubmitParams, "spark-submit-params", nil, "A list of parameters for jobs with Spark submit tasks.") + fs.StringToStringVar(&o.sqlParams, "sql-params", nil, "A map from keys to values for jobs with SQL tasks.") +} + +func (o *JobOptions) validatePipelineParams() (*jobs.PipelineParams, error) { + if len(o.pipelineParams) == 0 { + return nil, nil + } + + var defaultErr = fmt.Errorf("job run argument --pipeline-params only supports `full_refresh=`") + v, ok := o.pipelineParams["full_refresh"] + if !ok { + return nil, defaultErr + } + + b, err := strconv.ParseBool(v) + if err != nil { + return nil, defaultErr + } + + pipelineParams := &jobs.PipelineParams{ + FullRefresh: b, + } + + return pipelineParams, nil +} + +func (o *JobOptions) toPayload(jobID int64) (*jobs.RunNow, error) { + pipelineParams, err := o.validatePipelineParams() + if err != nil { + return nil, err + } + + payload := &jobs.RunNow{ + JobId: jobID, + + DbtCommands: o.dbtCommands, + JarParams: o.jarParams, + NotebookParams: o.notebookParams, + PipelineParams: pipelineParams, + PythonNamedParams: o.pythonNamedParams, + PythonParams: o.pythonParams, + SparkSubmitParams: o.sparkSubmitParams, + SqlParams: o.sqlParams, + } + + return payload, nil +} + // Default timeout for waiting for a job run to complete. var jobRunTimeout time.Duration = 2 * time.Hour @@ -23,7 +92,7 @@ type jobRunner struct { job *resources.Job } -func (r *jobRunner) Run(ctx context.Context) error { +func (r *jobRunner) Run(ctx context.Context, opts *Options) error { jobID, err := strconv.ParseInt(r.job.ID, 10, 64) if err != nil { return fmt.Errorf("job ID is not an integer: %s", r.job.ID) @@ -48,10 +117,13 @@ func (r *jobRunner) Run(ctx context.Context) error { } } + req, err := opts.Job.toPayload(jobID) + if err != nil { + return err + } + w := r.bundle.WorkspaceClient() - run, err := w.Jobs.RunNowAndWait(ctx, jobs.RunNow{ - JobId: jobID, - }, retries.Timeout[jobs.Run](jobRunTimeout), update) + run, err := w.Jobs.RunNowAndWait(ctx, *req, retries.Timeout[jobs.Run](jobRunTimeout), update) if err != nil { return err } diff --git a/bundle/run/options.go b/bundle/run/options.go new file mode 100644 index 00000000..7a550af8 --- /dev/null +++ b/bundle/run/options.go @@ -0,0 +1,13 @@ +package run + +import flag "github.com/spf13/pflag" + +type Options struct { + Job JobOptions + Pipeline PipelineOptions +} + +func (o *Options) Define(fs *flag.FlagSet) { + o.Job.Define(fs) + o.Pipeline.Define(fs) +} diff --git a/bundle/run/pipeline.go b/bundle/run/pipeline.go index 012925f9..d222d292 100644 --- a/bundle/run/pipeline.go +++ b/bundle/run/pipeline.go @@ -4,13 +4,73 @@ import ( "context" "fmt" "log" + "strings" "time" "github.com/databricks/bricks/bundle" "github.com/databricks/bricks/bundle/config/resources" "github.com/databricks/databricks-sdk-go/service/pipelines" + flag "github.com/spf13/pflag" ) +// PipelineOptions defines options for running a pipeline update. +type PipelineOptions struct { + // Perform a full graph update. + RefreshAll bool + + // List of tables to update. + Refresh []string + + // Perform a full graph reset and recompute. + FullRefreshAll bool + + // List of tables to reset and recompute. + FullRefresh []string +} + +func (o *PipelineOptions) Define(fs *flag.FlagSet) { + fs.BoolVar(&o.RefreshAll, "refresh-all", false, "Perform a full graph update.") + fs.StringSliceVar(&o.Refresh, "refresh", nil, "List of tables to update.") + fs.BoolVar(&o.FullRefreshAll, "full-refresh-all", false, "Perform a full graph reset and recompute.") + fs.StringSliceVar(&o.FullRefresh, "full-refresh", nil, "List of tables to reset and recompute.") +} + +// Validate returns if the combination of options is valid. +func (o *PipelineOptions) Validate() error { + set := []string{} + if o.RefreshAll { + set = append(set, "--refresh-all") + } + if len(o.Refresh) > 0 { + set = append(set, "--refresh") + } + if o.FullRefreshAll { + set = append(set, "--full-refresh-all") + } + if len(o.FullRefresh) > 0 { + set = append(set, "--full-refresh") + } + if len(set) > 1 { + return fmt.Errorf("pipeline run arguments are mutually exclusive (got %s)", strings.Join(set, ", ")) + } + return nil +} + +func (o *PipelineOptions) toPayload(pipelineID string) (*pipelines.StartUpdate, error) { + if err := o.Validate(); err != nil { + return nil, err + } + payload := &pipelines.StartUpdate{ + PipelineId: pipelineID, + + // Note: `RefreshAll` is implied if the fields below are not set. + RefreshSelection: o.Refresh, + FullRefresh: o.FullRefreshAll, + FullRefreshSelection: o.FullRefresh, + } + return payload, nil +} + type pipelineRunner struct { key @@ -18,7 +78,7 @@ type pipelineRunner struct { pipeline *resources.Pipeline } -func (r *pipelineRunner) Run(ctx context.Context) error { +func (r *pipelineRunner) Run(ctx context.Context, opts *Options) error { var prefix = fmt.Sprintf("[INFO] [%s]", r.Key()) var pipelineID = r.pipeline.ID @@ -29,9 +89,12 @@ func (r *pipelineRunner) Run(ctx context.Context) error { return err } - res, err := w.Pipelines.StartUpdate(ctx, pipelines.StartUpdate{ - PipelineId: pipelineID, - }) + req, err := opts.Pipeline.toPayload(pipelineID) + if err != nil { + return err + } + + res, err := w.Pipelines.StartUpdate(ctx, *req) if err != nil { return err } diff --git a/bundle/run/runner.go b/bundle/run/runner.go index e9c1aadd..2c6bd8c8 100644 --- a/bundle/run/runner.go +++ b/bundle/run/runner.go @@ -21,7 +21,7 @@ type Runner interface { Key() string // Run the underlying worklow. - Run(ctx context.Context) error + Run(ctx context.Context, opts *Options) error } // Find locates a runner matching the specified argument. diff --git a/cmd/bundle/run.go b/cmd/bundle/run.go index 93ba1823..b1cd616b 100644 --- a/cmd/bundle/run.go +++ b/cmd/bundle/run.go @@ -8,6 +8,8 @@ import ( "github.com/spf13/cobra" ) +var runOptions run.Options + var runCmd = &cobra.Command{ Use: "run [flags] KEY", Short: "Run a workload (e.g. a job or a pipeline)", @@ -30,7 +32,7 @@ var runCmd = &cobra.Command{ return err } - err = runner.Run(cmd.Context()) + err = runner.Run(cmd.Context(), &runOptions) if err != nil { return err } @@ -40,5 +42,6 @@ var runCmd = &cobra.Command{ } func init() { + runOptions.Define(runCmd.Flags()) rootCmd.AddCommand(runCmd) } diff --git a/go.mod b/go.mod index 8f0363a0..d61d56ae 100644 --- a/go.mod +++ b/go.mod @@ -45,7 +45,7 @@ require ( github.com/imdario/mergo v0.3.13 github.com/inconshreveable/mousetrap v1.0.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/spf13/pflag v1.0.5 // indirect + github.com/spf13/pflag v1.0.5 go.opencensus.io v0.24.0 // indirect golang.org/x/net v0.1.0 // indirect golang.org/x/oauth2 v0.0.0-20221014153046-6fdb5e3db783 // indirect