Define flags for running jobs and pipelines (#146)

Authored by Pieter Noordhuis on 2022-12-23 15:17:16 +01:00; committed by GitHub
parent 49aa858b89
commit 8f4461904b
6 changed files with 162 additions and 11 deletions

bundle/run/job.go

@@ -11,8 +11,77 @@ import (
 	"github.com/databricks/bricks/bundle/config/resources"
 	"github.com/databricks/databricks-sdk-go/retries"
 	"github.com/databricks/databricks-sdk-go/service/jobs"
+	flag "github.com/spf13/pflag"
 )
 
+// JobOptions defines options for running a job.
+type JobOptions struct {
+	dbtCommands       []string
+	jarParams         []string
+	notebookParams    map[string]string
+	pipelineParams    map[string]string
+	pythonNamedParams map[string]string
+	pythonParams      []string
+	sparkSubmitParams []string
+	sqlParams         map[string]string
+}
+
+func (o *JobOptions) Define(fs *flag.FlagSet) {
+	fs.StringSliceVar(&o.dbtCommands, "dbt-commands", nil, "A list of commands to execute for jobs with DBT tasks.")
+	fs.StringSliceVar(&o.jarParams, "jar-params", nil, "A list of parameters for jobs with Spark JAR tasks.")
+	fs.StringToStringVar(&o.notebookParams, "notebook-params", nil, "A map from keys to values for jobs with notebook tasks.")
+	fs.StringToStringVar(&o.pipelineParams, "pipeline-params", nil, "A map from keys to values for jobs with pipeline tasks.")
+	fs.StringToStringVar(&o.pythonNamedParams, "python-named-params", nil, "A map from keys to values for jobs with Python wheel tasks.")
+	fs.StringSliceVar(&o.pythonParams, "python-params", nil, "A list of parameters for jobs with Python tasks.")
+	fs.StringSliceVar(&o.sparkSubmitParams, "spark-submit-params", nil, "A list of parameters for jobs with Spark submit tasks.")
+	fs.StringToStringVar(&o.sqlParams, "sql-params", nil, "A map from keys to values for jobs with SQL tasks.")
+}
+
+func (o *JobOptions) validatePipelineParams() (*jobs.PipelineParams, error) {
+	if len(o.pipelineParams) == 0 {
+		return nil, nil
+	}
+
+	var defaultErr = fmt.Errorf("job run argument --pipeline-params only supports `full_refresh=<bool>`")
+	v, ok := o.pipelineParams["full_refresh"]
+	if !ok {
+		return nil, defaultErr
+	}
+
+	b, err := strconv.ParseBool(v)
+	if err != nil {
+		return nil, defaultErr
+	}
+
+	pipelineParams := &jobs.PipelineParams{
+		FullRefresh: b,
+	}
+
+	return pipelineParams, nil
+}
+
+func (o *JobOptions) toPayload(jobID int64) (*jobs.RunNow, error) {
+	pipelineParams, err := o.validatePipelineParams()
+	if err != nil {
+		return nil, err
+	}
+
+	payload := &jobs.RunNow{
+		JobId:             jobID,
+		DbtCommands:       o.dbtCommands,
+		JarParams:         o.jarParams,
+		NotebookParams:    o.notebookParams,
+		PipelineParams:    pipelineParams,
+		PythonNamedParams: o.pythonNamedParams,
+		PythonParams:      o.pythonParams,
+		SparkSubmitParams: o.sparkSubmitParams,
+		SqlParams:         o.sqlParams,
+	}
+
+	return payload, nil
+}
+
 // Default timeout for waiting for a job run to complete.
 var jobRunTimeout time.Duration = 2 * time.Hour
@@ -23,7 +92,7 @@ type jobRunner struct {
 	job *resources.Job
 }
 
-func (r *jobRunner) Run(ctx context.Context) error {
+func (r *jobRunner) Run(ctx context.Context, opts *Options) error {
 	jobID, err := strconv.ParseInt(r.job.ID, 10, 64)
 	if err != nil {
 		return fmt.Errorf("job ID is not an integer: %s", r.job.ID)
@@ -48,10 +117,13 @@ func (r *jobRunner) Run(ctx context.Context) error {
 		}
 	}
 
+	req, err := opts.Job.toPayload(jobID)
+	if err != nil {
+		return err
+	}
+
 	w := r.bundle.WorkspaceClient()
-	run, err := w.Jobs.RunNowAndWait(ctx, jobs.RunNow{
-		JobId: jobID,
-	}, retries.Timeout[jobs.Run](jobRunTimeout), update)
+	run, err := w.Jobs.RunNowAndWait(ctx, *req, retries.Timeout[jobs.Run](jobRunTimeout), update)
 	if err != nil {
 		return err
 	}
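
For context (not part of the diff): a minimal sketch of how these flags flow into the payload, written as a hypothetical in-package test for bundle/run. The flag names and jobs types are taken from the hunk above; the test itself is assumed.

package run

import (
	"testing"

	flag "github.com/spf13/pflag"
)

func TestJobOptionsToPayload(t *testing.T) {
	var opts JobOptions
	fs := flag.NewFlagSet("run", flag.ContinueOnError)
	opts.Define(fs)

	// Flags arrive exactly as they would from the command line.
	if err := fs.Parse([]string{"--pipeline-params", "full_refresh=true"}); err != nil {
		t.Fatal(err)
	}

	// toPayload validates --pipeline-params and folds all flags into jobs.RunNow.
	payload, err := opts.toPayload(123)
	if err != nil {
		t.Fatal(err)
	}
	if payload.PipelineParams == nil || !payload.PipelineParams.FullRefresh {
		t.Fatalf("expected a full refresh, got %#v", payload.PipelineParams)
	}

	// Any key other than full_refresh is rejected by validatePipelineParams.
	opts.pipelineParams = map[string]string{"foo": "bar"}
	if _, err := opts.toPayload(123); err == nil {
		t.Fatal("expected an error for unsupported pipeline params")
	}
}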

bundle/run/options.go (new file, 13 additions)

@@ -0,0 +1,13 @@
+package run
+
+import flag "github.com/spf13/pflag"
+
+type Options struct {
+	Job      JobOptions
+	Pipeline PipelineOptions
+}
+
+func (o *Options) Define(fs *flag.FlagSet) {
+	o.Job.Define(fs)
+	o.Pipeline.Define(fs)
+}
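
A quick hedged sketch (hypothetical, in-package) of what this composition buys: one flag set ends up carrying both the job and pipeline flags.

package run

import flag "github.com/spf13/pflag"

// exampleDefine registers every run flag on a single flag set and
// checks that flags from both groups are present by name.
func exampleDefine() bool {
	var opts Options
	fs := flag.NewFlagSet("run", flag.ContinueOnError)
	opts.Define(fs)
	return fs.Lookup("notebook-params") != nil && fs.Lookup("refresh-all") != nil
}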

bundle/run/pipeline.go

@@ -4,13 +4,73 @@ import (
 	"context"
 	"fmt"
 	"log"
+	"strings"
 	"time"
 
 	"github.com/databricks/bricks/bundle"
 	"github.com/databricks/bricks/bundle/config/resources"
 	"github.com/databricks/databricks-sdk-go/service/pipelines"
+	flag "github.com/spf13/pflag"
 )
 
+// PipelineOptions defines options for running a pipeline update.
+type PipelineOptions struct {
+	// Perform a full graph update.
+	RefreshAll bool
+
+	// List of tables to update.
+	Refresh []string
+
+	// Perform a full graph reset and recompute.
+	FullRefreshAll bool
+
+	// List of tables to reset and recompute.
+	FullRefresh []string
+}
+
+func (o *PipelineOptions) Define(fs *flag.FlagSet) {
+	fs.BoolVar(&o.RefreshAll, "refresh-all", false, "Perform a full graph update.")
+	fs.StringSliceVar(&o.Refresh, "refresh", nil, "List of tables to update.")
+	fs.BoolVar(&o.FullRefreshAll, "full-refresh-all", false, "Perform a full graph reset and recompute.")
+	fs.StringSliceVar(&o.FullRefresh, "full-refresh", nil, "List of tables to reset and recompute.")
+}
+
+// Validate returns an error if the combination of options is invalid.
+func (o *PipelineOptions) Validate() error {
+	set := []string{}
+	if o.RefreshAll {
+		set = append(set, "--refresh-all")
+	}
+	if len(o.Refresh) > 0 {
+		set = append(set, "--refresh")
+	}
+	if o.FullRefreshAll {
+		set = append(set, "--full-refresh-all")
+	}
+	if len(o.FullRefresh) > 0 {
+		set = append(set, "--full-refresh")
+	}
+	if len(set) > 1 {
+		return fmt.Errorf("pipeline run arguments are mutually exclusive (got %s)", strings.Join(set, ", "))
+	}
+	return nil
+}
+
+func (o *PipelineOptions) toPayload(pipelineID string) (*pipelines.StartUpdate, error) {
+	if err := o.Validate(); err != nil {
+		return nil, err
+	}
+	payload := &pipelines.StartUpdate{
+		PipelineId: pipelineID,
+
+		// Note: `RefreshAll` is implied if the fields below are not set.
+		RefreshSelection:     o.Refresh,
+		FullRefresh:          o.FullRefreshAll,
+		FullRefreshSelection: o.FullRefresh,
+	}
+	return payload, nil
+}
+
 type pipelineRunner struct {
 	key
@@ -18,7 +78,7 @@ type pipelineRunner struct {
 	pipeline *resources.Pipeline
 }
 
-func (r *pipelineRunner) Run(ctx context.Context) error {
+func (r *pipelineRunner) Run(ctx context.Context, opts *Options) error {
 	var prefix = fmt.Sprintf("[INFO] [%s]", r.Key())
 	var pipelineID = r.pipeline.ID
@@ -29,9 +89,12 @@ func (r *pipelineRunner) Run(ctx context.Context) error {
 		return err
 	}
 
-	res, err := w.Pipelines.StartUpdate(ctx, pipelines.StartUpdate{
-		PipelineId: pipelineID,
-	})
+	req, err := opts.Pipeline.toPayload(pipelineID)
+	if err != nil {
+		return err
+	}
+
+	res, err := w.Pipelines.StartUpdate(ctx, *req)
 	if err != nil {
 		return err
 	}
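
To illustrate the two behaviors encoded above (mutual exclusion and the implied full graph update), a hypothetical in-package sketch, not part of the commit:

package run

func examplePipelineOptions() {
	// Conflicting selections are rejected by Validate:
	// "pipeline run arguments are mutually exclusive (got --refresh-all, --full-refresh)"
	bad := PipelineOptions{RefreshAll: true, FullRefresh: []string{"sales"}}
	if err := bad.Validate(); err != nil {
		println(err.Error())
	}

	// With no flags set, the payload carries no selection; per the note in
	// toPayload, the service then performs a regular full graph update.
	var none PipelineOptions
	req, _ := none.toPayload("some-pipeline-id")
	_ = req // req.RefreshSelection == nil, req.FullRefresh == false
}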

bundle/run/runner.go

@@ -21,7 +21,7 @@ type Runner interface {
 	Key() string
 
 	// Run the underlying workflow.
-	Run(ctx context.Context) error
+	Run(ctx context.Context, opts *Options) error
 }
 
 // Find locates a runner matching the specified argument.
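
Every implementation now receives the options struct. A toy runner against the updated signature might look like this (a sketch only, assuming Key and Run are the methods the interface requires):

package run

import (
	"context"
	"fmt"
)

// noopRunner is a hypothetical runner that only reports which job
// flags it would have forwarded.
type noopRunner struct{ name string }

func (r *noopRunner) Key() string { return r.name }

func (r *noopRunner) Run(ctx context.Context, opts *Options) error {
	fmt.Printf("would run %s with dbt commands %v\n", r.name, opts.Job.dbtCommands)
	return nil
}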

cmd/bundle/run.go

@@ -8,6 +8,8 @@ import (
 	"github.com/spf13/cobra"
 )
 
+var runOptions run.Options
+
 var runCmd = &cobra.Command{
 	Use:   "run [flags] KEY",
 	Short: "Run a workload (e.g. a job or a pipeline)",
@@ -30,7 +32,7 @@ var runCmd = &cobra.Command{
 		return err
 	}
 
-	err = runner.Run(cmd.Context())
+	err = runner.Run(cmd.Context(), &runOptions)
 	if err != nil {
 		return err
 	}
@@ -40,5 +42,6 @@ var runCmd = &cobra.Command{
 }
 
 func init() {
+	runOptions.Define(runCmd.Flags())
 	rootCmd.AddCommand(runCmd)
 }
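
With the flags registered in init, the run command accepts the new parameters directly. For example (assuming the command is mounted as `bricks bundle run`): `bricks bundle run my_job --notebook-params key=value`, or `bricks bundle run my_pipeline --refresh sales,orders`. Combining conflicting pipeline flags, such as --refresh-all with --full-refresh, fails with the mutual-exclusion error defined in PipelineOptions.Validate.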

go.mod (2 changes)

@@ -45,7 +45,7 @@ require (
 	github.com/imdario/mergo v0.3.13
 	github.com/inconshreveable/mousetrap v1.0.1 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
-	github.com/spf13/pflag v1.0.5 // indirect
+	github.com/spf13/pflag v1.0.5
 	go.opencensus.io v0.24.0 // indirect
 	golang.org/x/net v0.1.0 // indirect
 	golang.org/x/oauth2 v0.0.0-20221014153046-6fdb5e3db783 // indirect