mirror of https://github.com/databricks/cli.git
Define flags for running jobs and pipelines (#146)
This commit is contained in:
parent
49aa858b89
commit
8f4461904b
|
@ -11,8 +11,77 @@ import (
|
|||
"github.com/databricks/bricks/bundle/config/resources"
|
||||
"github.com/databricks/databricks-sdk-go/retries"
|
||||
"github.com/databricks/databricks-sdk-go/service/jobs"
|
||||
flag "github.com/spf13/pflag"
|
||||
)
|
||||
|
||||
// JobOptions defines options for running a job.
//
// Each field mirrors one task-type-specific parameter set accepted by the
// Jobs RunNow API; the fields are populated from command-line flags in
// Define and assembled into a request in toPayload.
type JobOptions struct {
	// Commands to execute for jobs with DBT tasks.
	dbtCommands []string
	// Parameters for jobs with Spark JAR tasks.
	jarParams []string
	// Key/value parameters for jobs with notebook tasks.
	notebookParams map[string]string
	// Key/value parameters for jobs with pipeline tasks.
	pipelineParams map[string]string
	// Named key/value parameters for jobs with Python wheel tasks.
	pythonNamedParams map[string]string
	// Parameters for jobs with Python tasks.
	pythonParams []string
	// Parameters for jobs with Spark submit tasks.
	sparkSubmitParams []string
	// Key/value parameters for jobs with SQL tasks.
	sqlParams map[string]string
}
|
||||
|
||||
func (o *JobOptions) Define(fs *flag.FlagSet) {
|
||||
fs.StringSliceVar(&o.dbtCommands, "dbt-commands", nil, "A list of commands to execute for jobs with DBT tasks.")
|
||||
fs.StringSliceVar(&o.jarParams, "jar-params", nil, "A list of parameters for jobs with Spark JAR tasks.")
|
||||
fs.StringToStringVar(&o.notebookParams, "notebook-params", nil, "A map from keys to values for jobs with notebook tasks.")
|
||||
fs.StringToStringVar(&o.pipelineParams, "pipeline-params", nil, "A map from keys to values for jobs with pipeline tasks.")
|
||||
fs.StringToStringVar(&o.pythonNamedParams, "python-named-params", nil, "A map from keys to values for jobs with Python wheel tasks.")
|
||||
fs.StringSliceVar(&o.pythonParams, "python-params", nil, "A list of parameters for jobs with Python tasks.")
|
||||
fs.StringSliceVar(&o.sparkSubmitParams, "spark-submit-params", nil, "A list of parameters for jobs with Spark submit tasks.")
|
||||
fs.StringToStringVar(&o.sqlParams, "sql-params", nil, "A map from keys to values for jobs with SQL tasks.")
|
||||
}
|
||||
|
||||
func (o *JobOptions) validatePipelineParams() (*jobs.PipelineParams, error) {
|
||||
if len(o.pipelineParams) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var defaultErr = fmt.Errorf("job run argument --pipeline-params only supports `full_refresh=<bool>`")
|
||||
v, ok := o.pipelineParams["full_refresh"]
|
||||
if !ok {
|
||||
return nil, defaultErr
|
||||
}
|
||||
|
||||
b, err := strconv.ParseBool(v)
|
||||
if err != nil {
|
||||
return nil, defaultErr
|
||||
}
|
||||
|
||||
pipelineParams := &jobs.PipelineParams{
|
||||
FullRefresh: b,
|
||||
}
|
||||
|
||||
return pipelineParams, nil
|
||||
}
|
||||
|
||||
func (o *JobOptions) toPayload(jobID int64) (*jobs.RunNow, error) {
|
||||
pipelineParams, err := o.validatePipelineParams()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload := &jobs.RunNow{
|
||||
JobId: jobID,
|
||||
|
||||
DbtCommands: o.dbtCommands,
|
||||
JarParams: o.jarParams,
|
||||
NotebookParams: o.notebookParams,
|
||||
PipelineParams: pipelineParams,
|
||||
PythonNamedParams: o.pythonNamedParams,
|
||||
PythonParams: o.pythonParams,
|
||||
SparkSubmitParams: o.sparkSubmitParams,
|
||||
SqlParams: o.sqlParams,
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// Default timeout for waiting for a job run to complete.
|
||||
var jobRunTimeout time.Duration = 2 * time.Hour
|
||||
|
||||
|
@ -23,7 +92,7 @@ type jobRunner struct {
|
|||
job *resources.Job
|
||||
}
|
||||
|
||||
func (r *jobRunner) Run(ctx context.Context) error {
|
||||
func (r *jobRunner) Run(ctx context.Context, opts *Options) error {
|
||||
jobID, err := strconv.ParseInt(r.job.ID, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("job ID is not an integer: %s", r.job.ID)
|
||||
|
@ -48,10 +117,13 @@ func (r *jobRunner) Run(ctx context.Context) error {
|
|||
}
|
||||
}
|
||||
|
||||
req, err := opts.Job.toPayload(jobID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
w := r.bundle.WorkspaceClient()
|
||||
run, err := w.Jobs.RunNowAndWait(ctx, jobs.RunNow{
|
||||
JobId: jobID,
|
||||
}, retries.Timeout[jobs.Run](jobRunTimeout), update)
|
||||
run, err := w.Jobs.RunNowAndWait(ctx, *req, retries.Timeout[jobs.Run](jobRunTimeout), update)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
package run
|
||||
|
||||
import flag "github.com/spf13/pflag"
|
||||
|
||||
type Options struct {
|
||||
Job JobOptions
|
||||
Pipeline PipelineOptions
|
||||
}
|
||||
|
||||
func (o *Options) Define(fs *flag.FlagSet) {
|
||||
o.Job.Define(fs)
|
||||
o.Pipeline.Define(fs)
|
||||
}
|
|
@ -4,13 +4,73 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/databricks/bricks/bundle"
|
||||
"github.com/databricks/bricks/bundle/config/resources"
|
||||
"github.com/databricks/databricks-sdk-go/service/pipelines"
|
||||
flag "github.com/spf13/pflag"
|
||||
)
|
||||
|
||||
// PipelineOptions defines options for running a pipeline update.
//
// The four fields are mutually exclusive selections; see Validate.
type PipelineOptions struct {
	// RefreshAll requests a full graph update.
	RefreshAll bool

	// Refresh lists the tables to update.
	Refresh []string

	// FullRefreshAll requests a full graph reset and recompute.
	FullRefreshAll bool

	// FullRefresh lists the tables to reset and recompute.
	FullRefresh []string
}
|
||||
|
||||
func (o *PipelineOptions) Define(fs *flag.FlagSet) {
|
||||
fs.BoolVar(&o.RefreshAll, "refresh-all", false, "Perform a full graph update.")
|
||||
fs.StringSliceVar(&o.Refresh, "refresh", nil, "List of tables to update.")
|
||||
fs.BoolVar(&o.FullRefreshAll, "full-refresh-all", false, "Perform a full graph reset and recompute.")
|
||||
fs.StringSliceVar(&o.FullRefresh, "full-refresh", nil, "List of tables to reset and recompute.")
|
||||
}
|
||||
|
||||
// Validate returns if the combination of options is valid.
|
||||
func (o *PipelineOptions) Validate() error {
|
||||
set := []string{}
|
||||
if o.RefreshAll {
|
||||
set = append(set, "--refresh-all")
|
||||
}
|
||||
if len(o.Refresh) > 0 {
|
||||
set = append(set, "--refresh")
|
||||
}
|
||||
if o.FullRefreshAll {
|
||||
set = append(set, "--full-refresh-all")
|
||||
}
|
||||
if len(o.FullRefresh) > 0 {
|
||||
set = append(set, "--full-refresh")
|
||||
}
|
||||
if len(set) > 1 {
|
||||
return fmt.Errorf("pipeline run arguments are mutually exclusive (got %s)", strings.Join(set, ", "))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *PipelineOptions) toPayload(pipelineID string) (*pipelines.StartUpdate, error) {
|
||||
if err := o.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload := &pipelines.StartUpdate{
|
||||
PipelineId: pipelineID,
|
||||
|
||||
// Note: `RefreshAll` is implied if the fields below are not set.
|
||||
RefreshSelection: o.Refresh,
|
||||
FullRefresh: o.FullRefreshAll,
|
||||
FullRefreshSelection: o.FullRefresh,
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
type pipelineRunner struct {
|
||||
key
|
||||
|
||||
|
@ -18,7 +78,7 @@ type pipelineRunner struct {
|
|||
pipeline *resources.Pipeline
|
||||
}
|
||||
|
||||
func (r *pipelineRunner) Run(ctx context.Context) error {
|
||||
func (r *pipelineRunner) Run(ctx context.Context, opts *Options) error {
|
||||
var prefix = fmt.Sprintf("[INFO] [%s]", r.Key())
|
||||
var pipelineID = r.pipeline.ID
|
||||
|
||||
|
@ -29,9 +89,12 @@ func (r *pipelineRunner) Run(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
res, err := w.Pipelines.StartUpdate(ctx, pipelines.StartUpdate{
|
||||
PipelineId: pipelineID,
|
||||
})
|
||||
req, err := opts.Pipeline.toPayload(pipelineID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res, err := w.Pipelines.StartUpdate(ctx, *req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ type Runner interface {
|
|||
Key() string
|
||||
|
||||
// Run the underlying workflow.
|
||||
Run(ctx context.Context) error
|
||||
Run(ctx context.Context, opts *Options) error
|
||||
}
|
||||
|
||||
// Find locates a runner matching the specified argument.
|
||||
|
|
|
@ -8,6 +8,8 @@ import (
|
|||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var runOptions run.Options
|
||||
|
||||
var runCmd = &cobra.Command{
|
||||
Use: "run [flags] KEY",
|
||||
Short: "Run a workload (e.g. a job or a pipeline)",
|
||||
|
@ -30,7 +32,7 @@ var runCmd = &cobra.Command{
|
|||
return err
|
||||
}
|
||||
|
||||
err = runner.Run(cmd.Context())
|
||||
err = runner.Run(cmd.Context(), &runOptions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -40,5 +42,6 @@ var runCmd = &cobra.Command{
|
|||
}
|
||||
|
||||
func init() {
|
||||
runOptions.Define(runCmd.Flags())
|
||||
rootCmd.AddCommand(runCmd)
|
||||
}
|
||||
|
|
2
go.mod
2
go.mod
|
@ -45,7 +45,7 @@ require (
|
|||
github.com/imdario/mergo v0.3.13
|
||||
github.com/inconshreveable/mousetrap v1.0.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/spf13/pflag v1.0.5
|
||||
go.opencensus.io v0.24.0 // indirect
|
||||
golang.org/x/net v0.1.0 // indirect
|
||||
golang.org/x/oauth2 v0.0.0-20221014153046-6fdb5e3db783 // indirect
|
||||
|
|
Loading…
Reference in New Issue