From bc30c9ed4a2670bce568379cc4f743360707fd79 Mon Sep 17 00:00:00 2001 From: Andrew Nester Date: Fri, 9 Feb 2024 15:33:14 +0100 Subject: [PATCH] Added `--restart` flag for `bundle run` command (#1191) ## Changes Added `--restart` flag for `bundle run` command When running with this flag, `bundle run` will cancel all existing runs before starting a new one ## Tests Manually --- bundle/run/job.go | 40 +++++++++++++++++++ bundle/run/job_test.go | 79 +++++++++++++++++++++++++++++++++++++ bundle/run/pipeline.go | 15 +++++++ bundle/run/pipeline_test.go | 49 +++++++++++++++++++++++ bundle/run/runner.go | 3 ++ cmd/bundle/run.go | 11 ++++++ 6 files changed, 197 insertions(+) create mode 100644 bundle/run/pipeline_test.go diff --git a/bundle/run/job.go b/bundle/run/job.go index a54279c1..043ea846 100644 --- a/bundle/run/job.go +++ b/bundle/run/job.go @@ -15,6 +15,7 @@ import ( "github.com/databricks/cli/libs/log" "github.com/databricks/databricks-sdk-go/service/jobs" "github.com/fatih/color" + "golang.org/x/sync/errgroup" ) // Default timeout for waiting for a job run to complete. @@ -275,3 +276,42 @@ func (r *jobRunner) convertPythonParams(opts *Options) error { return nil } + +func (r *jobRunner) Cancel(ctx context.Context) error { + w := r.bundle.WorkspaceClient() + jobID, err := strconv.ParseInt(r.job.ID, 10, 64) + if err != nil { + return fmt.Errorf("job ID is not an integer: %s", r.job.ID) + } + + runs, err := w.Jobs.ListRunsAll(ctx, jobs.ListRunsRequest{ + ActiveOnly: true, + JobId: jobID, + }) + + if err != nil { + return err + } + + if len(runs) == 0 { + return nil + } + + errGroup, errCtx := errgroup.WithContext(ctx) + for _, run := range runs { + runId := run.RunId + errGroup.Go(func() error { + wait, err := w.Jobs.CancelRun(errCtx, jobs.CancelRun{ + RunId: runId, + }) + if err != nil { + return err + } + // Waits for the Terminated or Skipped state + _, err = wait.GetWithTimeout(jobRunTimeout) + return err + }) + } + + return errGroup.Wait() +} diff --git a/bundle/run/job_test.go b/bundle/run/job_test.go index e4cb4e7e..be189306 100644 --- a/bundle/run/job_test.go +++ b/bundle/run/job_test.go @@ -1,12 +1,16 @@ package run import ( + "context" "testing" + "time" "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config" "github.com/databricks/cli/bundle/config/resources" + "github.com/databricks/databricks-sdk-go/experimental/mocks" "github.com/databricks/databricks-sdk-go/service/jobs" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -47,3 +51,78 @@ func TestConvertPythonParams(t *testing.T) { require.Contains(t, opts.Job.notebookParams, "__python_params") require.Equal(t, opts.Job.notebookParams["__python_params"], `["param1","param2","param3"]`) } + +func TestJobRunnerCancel(t *testing.T) { + job := &resources.Job{ + ID: "123", + } + b := &bundle.Bundle{ + Config: config.Root{ + Resources: config.Resources{ + Jobs: map[string]*resources.Job{ + "test_job": job, + }, + }, + }, + } + + runner := jobRunner{key: "test", bundle: b, job: job} + + m := mocks.NewMockWorkspaceClient(t) + b.SetWorkpaceClient(m.WorkspaceClient) + + jobApi := m.GetMockJobsAPI() + jobApi.EXPECT().ListRunsAll(mock.Anything, jobs.ListRunsRequest{ + ActiveOnly: true, + JobId: 123, + }).Return([]jobs.BaseRun{ + {RunId: 1}, + {RunId: 2}, + }, nil) + + mockWait := &jobs.WaitGetRunJobTerminatedOrSkipped[struct{}]{ + Poll: func(time time.Duration, f func(j *jobs.Run)) (*jobs.Run, error) { + return nil, nil + }, + } + jobApi.EXPECT().CancelRun(mock.Anything, jobs.CancelRun{ + RunId: 1, + }).Return(mockWait, nil) + jobApi.EXPECT().CancelRun(mock.Anything, jobs.CancelRun{ + RunId: 2, + }).Return(mockWait, nil) + + err := runner.Cancel(context.Background()) + require.NoError(t, err) +} + +func TestJobRunnerCancelWithNoActiveRuns(t *testing.T) { + job := &resources.Job{ + ID: "123", + } + b := &bundle.Bundle{ + Config: config.Root{ + Resources: config.Resources{ + Jobs: map[string]*resources.Job{ + "test_job": job, + }, + }, + }, + } + + runner := jobRunner{key: "test", bundle: b, job: job} + + m := mocks.NewMockWorkspaceClient(t) + b.SetWorkpaceClient(m.WorkspaceClient) + + jobApi := m.GetMockJobsAPI() + jobApi.EXPECT().ListRunsAll(mock.Anything, jobs.ListRunsRequest{ + ActiveOnly: true, + JobId: 123, + }).Return([]jobs.BaseRun{}, nil) + + jobApi.AssertNotCalled(t, "CancelRun") + + err := runner.Cancel(context.Background()) + require.NoError(t, err) +} diff --git a/bundle/run/pipeline.go b/bundle/run/pipeline.go index 342a771b..e1f5bfe5 100644 --- a/bundle/run/pipeline.go +++ b/bundle/run/pipeline.go @@ -166,3 +166,18 @@ func (r *pipelineRunner) Run(ctx context.Context, opts *Options) (output.RunOutp time.Sleep(time.Second) } } + +func (r *pipelineRunner) Cancel(ctx context.Context) error { + w := r.bundle.WorkspaceClient() + wait, err := w.Pipelines.Stop(ctx, pipelines.StopRequest{ + PipelineId: r.pipeline.ID, + }) + + if err != nil { + return err + } + + // Waits for the Idle state of the pipeline + _, err = wait.GetWithTimeout(jobRunTimeout) + return err +} diff --git a/bundle/run/pipeline_test.go b/bundle/run/pipeline_test.go new file mode 100644 index 00000000..29b57ffd --- /dev/null +++ b/bundle/run/pipeline_test.go @@ -0,0 +1,49 @@ +package run + +import ( + "context" + "testing" + "time" + + "github.com/databricks/cli/bundle" + "github.com/databricks/cli/bundle/config" + "github.com/databricks/cli/bundle/config/resources" + "github.com/databricks/databricks-sdk-go/experimental/mocks" + "github.com/databricks/databricks-sdk-go/service/pipelines" + "github.com/stretchr/testify/require" +) + +func TestPipelineRunnerCancel(t *testing.T) { + pipeline := &resources.Pipeline{ + ID: "123", + } + + b := &bundle.Bundle{ + Config: config.Root{ + Resources: config.Resources{ + Pipelines: map[string]*resources.Pipeline{ + "test_pipeline": pipeline, + }, + }, + }, + } + + runner := pipelineRunner{key: "test", bundle: b, pipeline: pipeline} + + m := mocks.NewMockWorkspaceClient(t) + b.SetWorkpaceClient(m.WorkspaceClient) + + mockWait := &pipelines.WaitGetPipelineIdle[struct{}]{ + Poll: func(time.Duration, func(*pipelines.GetPipelineResponse)) (*pipelines.GetPipelineResponse, error) { + return nil, nil + }, + } + + pipelineApi := m.GetMockPipelinesAPI() + pipelineApi.EXPECT().Stop(context.Background(), pipelines.StopRequest{ + PipelineId: "123", + }).Return(mockWait, nil) + + err := runner.Cancel(context.Background()) + require.NoError(t, err) +} diff --git a/bundle/run/runner.go b/bundle/run/runner.go index 7d3c2c29..de2a1ae7 100644 --- a/bundle/run/runner.go +++ b/bundle/run/runner.go @@ -26,6 +26,9 @@ type Runner interface { // Run the underlying worklow. Run(ctx context.Context, opts *Options) (output.RunOutput, error) + + // Cancel the underlying workflow. + Cancel(ctx context.Context) error } // Find locates a runner matching the specified argument. diff --git a/cmd/bundle/run.go b/cmd/bundle/run.go index a4b10658..c1a8d4ea 100644 --- a/cmd/bundle/run.go +++ b/cmd/bundle/run.go @@ -27,7 +27,9 @@ func newRunCommand() *cobra.Command { runOptions.Define(cmd) var noWait bool + var restart bool cmd.Flags().BoolVar(&noWait, "no-wait", false, "Don't wait for the run to complete.") + cmd.Flags().BoolVar(&restart, "restart", false, "Restart the run if it is already running.") cmd.RunE = func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() @@ -68,6 +70,15 @@ func newRunCommand() *cobra.Command { } runOptions.NoWait = noWait + if restart { + s := cmdio.Spinner(ctx) + s <- "Cancelling all runs" + err := runner.Cancel(ctx) + close(s) + if err != nil { + return err + } + } output, err := runner.Run(ctx, &runOptions) if err != nil { return err