Added `--restart` flag for `bundle run` command (#1191)

## Changes
Added `--restart` flag for `bundle run` command

When running with this flag, `bundle run` will cancel all existing runs
before starting a new one

## Tests
Manually
This commit is contained in:
Andrew Nester 2024-02-09 15:33:14 +01:00 committed by GitHub
parent cac112c5bc
commit bc30c9ed4a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 197 additions and 0 deletions

View File

@ -15,6 +15,7 @@ import (
"github.com/databricks/cli/libs/log"
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/fatih/color"
"golang.org/x/sync/errgroup"
)
// Default timeout for waiting for a job run to complete.
@ -275,3 +276,42 @@ func (r *jobRunner) convertPythonParams(opts *Options) error {
return nil
}
func (r *jobRunner) Cancel(ctx context.Context) error {
w := r.bundle.WorkspaceClient()
jobID, err := strconv.ParseInt(r.job.ID, 10, 64)
if err != nil {
return fmt.Errorf("job ID is not an integer: %s", r.job.ID)
}
runs, err := w.Jobs.ListRunsAll(ctx, jobs.ListRunsRequest{
ActiveOnly: true,
JobId: jobID,
})
if err != nil {
return err
}
if len(runs) == 0 {
return nil
}
errGroup, errCtx := errgroup.WithContext(ctx)
for _, run := range runs {
runId := run.RunId
errGroup.Go(func() error {
wait, err := w.Jobs.CancelRun(errCtx, jobs.CancelRun{
RunId: runId,
})
if err != nil {
return err
}
// Waits for the Terminated or Skipped state
_, err = wait.GetWithTimeout(jobRunTimeout)
return err
})
}
return errGroup.Wait()
}

View File

@ -1,12 +1,16 @@
package run
import (
"context"
"testing"
"time"
"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/resources"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
@ -47,3 +51,78 @@ func TestConvertPythonParams(t *testing.T) {
require.Contains(t, opts.Job.notebookParams, "__python_params")
require.Equal(t, opts.Job.notebookParams["__python_params"], `["param1","param2","param3"]`)
}
func TestJobRunnerCancel(t *testing.T) {
job := &resources.Job{
ID: "123",
}
b := &bundle.Bundle{
Config: config.Root{
Resources: config.Resources{
Jobs: map[string]*resources.Job{
"test_job": job,
},
},
},
}
runner := jobRunner{key: "test", bundle: b, job: job}
m := mocks.NewMockWorkspaceClient(t)
b.SetWorkpaceClient(m.WorkspaceClient)
jobApi := m.GetMockJobsAPI()
jobApi.EXPECT().ListRunsAll(mock.Anything, jobs.ListRunsRequest{
ActiveOnly: true,
JobId: 123,
}).Return([]jobs.BaseRun{
{RunId: 1},
{RunId: 2},
}, nil)
mockWait := &jobs.WaitGetRunJobTerminatedOrSkipped[struct{}]{
Poll: func(time time.Duration, f func(j *jobs.Run)) (*jobs.Run, error) {
return nil, nil
},
}
jobApi.EXPECT().CancelRun(mock.Anything, jobs.CancelRun{
RunId: 1,
}).Return(mockWait, nil)
jobApi.EXPECT().CancelRun(mock.Anything, jobs.CancelRun{
RunId: 2,
}).Return(mockWait, nil)
err := runner.Cancel(context.Background())
require.NoError(t, err)
}
func TestJobRunnerCancelWithNoActiveRuns(t *testing.T) {
job := &resources.Job{
ID: "123",
}
b := &bundle.Bundle{
Config: config.Root{
Resources: config.Resources{
Jobs: map[string]*resources.Job{
"test_job": job,
},
},
},
}
runner := jobRunner{key: "test", bundle: b, job: job}
m := mocks.NewMockWorkspaceClient(t)
b.SetWorkpaceClient(m.WorkspaceClient)
jobApi := m.GetMockJobsAPI()
jobApi.EXPECT().ListRunsAll(mock.Anything, jobs.ListRunsRequest{
ActiveOnly: true,
JobId: 123,
}).Return([]jobs.BaseRun{}, nil)
jobApi.AssertNotCalled(t, "CancelRun")
err := runner.Cancel(context.Background())
require.NoError(t, err)
}

View File

@ -166,3 +166,18 @@ func (r *pipelineRunner) Run(ctx context.Context, opts *Options) (output.RunOutp
time.Sleep(time.Second)
}
}
func (r *pipelineRunner) Cancel(ctx context.Context) error {
w := r.bundle.WorkspaceClient()
wait, err := w.Pipelines.Stop(ctx, pipelines.StopRequest{
PipelineId: r.pipeline.ID,
})
if err != nil {
return err
}
// Waits for the Idle state of the pipeline
_, err = wait.GetWithTimeout(jobRunTimeout)
return err
}

View File

@ -0,0 +1,49 @@
package run
import (
"context"
"testing"
"time"
"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/resources"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/pipelines"
"github.com/stretchr/testify/require"
)
func TestPipelineRunnerCancel(t *testing.T) {
pipeline := &resources.Pipeline{
ID: "123",
}
b := &bundle.Bundle{
Config: config.Root{
Resources: config.Resources{
Pipelines: map[string]*resources.Pipeline{
"test_pipeline": pipeline,
},
},
},
}
runner := pipelineRunner{key: "test", bundle: b, pipeline: pipeline}
m := mocks.NewMockWorkspaceClient(t)
b.SetWorkpaceClient(m.WorkspaceClient)
mockWait := &pipelines.WaitGetPipelineIdle[struct{}]{
Poll: func(time.Duration, func(*pipelines.GetPipelineResponse)) (*pipelines.GetPipelineResponse, error) {
return nil, nil
},
}
pipelineApi := m.GetMockPipelinesAPI()
pipelineApi.EXPECT().Stop(context.Background(), pipelines.StopRequest{
PipelineId: "123",
}).Return(mockWait, nil)
err := runner.Cancel(context.Background())
require.NoError(t, err)
}

View File

@ -26,6 +26,9 @@ type Runner interface {
// Run the underlying worklow.
Run(ctx context.Context, opts *Options) (output.RunOutput, error)
// Cancel the underlying workflow.
Cancel(ctx context.Context) error
}
// Find locates a runner matching the specified argument.

View File

@ -27,7 +27,9 @@ func newRunCommand() *cobra.Command {
runOptions.Define(cmd)
var noWait bool
var restart bool
cmd.Flags().BoolVar(&noWait, "no-wait", false, "Don't wait for the run to complete.")
cmd.Flags().BoolVar(&restart, "restart", false, "Restart the run if it is already running.")
cmd.RunE = func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
@ -68,6 +70,15 @@ func newRunCommand() *cobra.Command {
}
runOptions.NoWait = noWait
if restart {
s := cmdio.Spinner(ctx)
s <- "Cancelling all runs"
err := runner.Cancel(ctx)
close(s)
if err != nil {
return err
}
}
output, err := runner.Run(ctx, &runOptions)
if err != nil {
return err