mirror of https://github.com/databricks/cli.git
Added `--restart` flag for `bundle run` command (#1191)
## Changes Added `--restart` flag for `bundle run` command When running with this flag, `bundle run` will cancel all existing runs before starting a new one ## Tests Manually
This commit is contained in:
parent
cac112c5bc
commit
bc30c9ed4a
|
@ -15,6 +15,7 @@ import (
|
||||||
"github.com/databricks/cli/libs/log"
|
"github.com/databricks/cli/libs/log"
|
||||||
"github.com/databricks/databricks-sdk-go/service/jobs"
|
"github.com/databricks/databricks-sdk-go/service/jobs"
|
||||||
"github.com/fatih/color"
|
"github.com/fatih/color"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Default timeout for waiting for a job run to complete.
|
// Default timeout for waiting for a job run to complete.
|
||||||
|
@ -275,3 +276,42 @@ func (r *jobRunner) convertPythonParams(opts *Options) error {
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *jobRunner) Cancel(ctx context.Context) error {
|
||||||
|
w := r.bundle.WorkspaceClient()
|
||||||
|
jobID, err := strconv.ParseInt(r.job.ID, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("job ID is not an integer: %s", r.job.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
runs, err := w.Jobs.ListRunsAll(ctx, jobs.ListRunsRequest{
|
||||||
|
ActiveOnly: true,
|
||||||
|
JobId: jobID,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(runs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
errGroup, errCtx := errgroup.WithContext(ctx)
|
||||||
|
for _, run := range runs {
|
||||||
|
runId := run.RunId
|
||||||
|
errGroup.Go(func() error {
|
||||||
|
wait, err := w.Jobs.CancelRun(errCtx, jobs.CancelRun{
|
||||||
|
RunId: runId,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Waits for the Terminated or Skipped state
|
||||||
|
_, err = wait.GetWithTimeout(jobRunTimeout)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return errGroup.Wait()
|
||||||
|
}
|
||||||
|
|
|
@ -1,12 +1,16 @@
|
||||||
package run
|
package run
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/databricks/cli/bundle"
|
"github.com/databricks/cli/bundle"
|
||||||
"github.com/databricks/cli/bundle/config"
|
"github.com/databricks/cli/bundle/config"
|
||||||
"github.com/databricks/cli/bundle/config/resources"
|
"github.com/databricks/cli/bundle/config/resources"
|
||||||
|
"github.com/databricks/databricks-sdk-go/experimental/mocks"
|
||||||
"github.com/databricks/databricks-sdk-go/service/jobs"
|
"github.com/databricks/databricks-sdk-go/service/jobs"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -47,3 +51,78 @@ func TestConvertPythonParams(t *testing.T) {
|
||||||
require.Contains(t, opts.Job.notebookParams, "__python_params")
|
require.Contains(t, opts.Job.notebookParams, "__python_params")
|
||||||
require.Equal(t, opts.Job.notebookParams["__python_params"], `["param1","param2","param3"]`)
|
require.Equal(t, opts.Job.notebookParams["__python_params"], `["param1","param2","param3"]`)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestJobRunnerCancel(t *testing.T) {
|
||||||
|
job := &resources.Job{
|
||||||
|
ID: "123",
|
||||||
|
}
|
||||||
|
b := &bundle.Bundle{
|
||||||
|
Config: config.Root{
|
||||||
|
Resources: config.Resources{
|
||||||
|
Jobs: map[string]*resources.Job{
|
||||||
|
"test_job": job,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
runner := jobRunner{key: "test", bundle: b, job: job}
|
||||||
|
|
||||||
|
m := mocks.NewMockWorkspaceClient(t)
|
||||||
|
b.SetWorkpaceClient(m.WorkspaceClient)
|
||||||
|
|
||||||
|
jobApi := m.GetMockJobsAPI()
|
||||||
|
jobApi.EXPECT().ListRunsAll(mock.Anything, jobs.ListRunsRequest{
|
||||||
|
ActiveOnly: true,
|
||||||
|
JobId: 123,
|
||||||
|
}).Return([]jobs.BaseRun{
|
||||||
|
{RunId: 1},
|
||||||
|
{RunId: 2},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
mockWait := &jobs.WaitGetRunJobTerminatedOrSkipped[struct{}]{
|
||||||
|
Poll: func(time time.Duration, f func(j *jobs.Run)) (*jobs.Run, error) {
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
jobApi.EXPECT().CancelRun(mock.Anything, jobs.CancelRun{
|
||||||
|
RunId: 1,
|
||||||
|
}).Return(mockWait, nil)
|
||||||
|
jobApi.EXPECT().CancelRun(mock.Anything, jobs.CancelRun{
|
||||||
|
RunId: 2,
|
||||||
|
}).Return(mockWait, nil)
|
||||||
|
|
||||||
|
err := runner.Cancel(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJobRunnerCancelWithNoActiveRuns(t *testing.T) {
|
||||||
|
job := &resources.Job{
|
||||||
|
ID: "123",
|
||||||
|
}
|
||||||
|
b := &bundle.Bundle{
|
||||||
|
Config: config.Root{
|
||||||
|
Resources: config.Resources{
|
||||||
|
Jobs: map[string]*resources.Job{
|
||||||
|
"test_job": job,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
runner := jobRunner{key: "test", bundle: b, job: job}
|
||||||
|
|
||||||
|
m := mocks.NewMockWorkspaceClient(t)
|
||||||
|
b.SetWorkpaceClient(m.WorkspaceClient)
|
||||||
|
|
||||||
|
jobApi := m.GetMockJobsAPI()
|
||||||
|
jobApi.EXPECT().ListRunsAll(mock.Anything, jobs.ListRunsRequest{
|
||||||
|
ActiveOnly: true,
|
||||||
|
JobId: 123,
|
||||||
|
}).Return([]jobs.BaseRun{}, nil)
|
||||||
|
|
||||||
|
jobApi.AssertNotCalled(t, "CancelRun")
|
||||||
|
|
||||||
|
err := runner.Cancel(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
|
@ -166,3 +166,18 @@ func (r *pipelineRunner) Run(ctx context.Context, opts *Options) (output.RunOutp
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *pipelineRunner) Cancel(ctx context.Context) error {
|
||||||
|
w := r.bundle.WorkspaceClient()
|
||||||
|
wait, err := w.Pipelines.Stop(ctx, pipelines.StopRequest{
|
||||||
|
PipelineId: r.pipeline.ID,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Waits for the Idle state of the pipeline
|
||||||
|
_, err = wait.GetWithTimeout(jobRunTimeout)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
package run
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/databricks/cli/bundle"
|
||||||
|
"github.com/databricks/cli/bundle/config"
|
||||||
|
"github.com/databricks/cli/bundle/config/resources"
|
||||||
|
"github.com/databricks/databricks-sdk-go/experimental/mocks"
|
||||||
|
"github.com/databricks/databricks-sdk-go/service/pipelines"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPipelineRunnerCancel(t *testing.T) {
|
||||||
|
pipeline := &resources.Pipeline{
|
||||||
|
ID: "123",
|
||||||
|
}
|
||||||
|
|
||||||
|
b := &bundle.Bundle{
|
||||||
|
Config: config.Root{
|
||||||
|
Resources: config.Resources{
|
||||||
|
Pipelines: map[string]*resources.Pipeline{
|
||||||
|
"test_pipeline": pipeline,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
runner := pipelineRunner{key: "test", bundle: b, pipeline: pipeline}
|
||||||
|
|
||||||
|
m := mocks.NewMockWorkspaceClient(t)
|
||||||
|
b.SetWorkpaceClient(m.WorkspaceClient)
|
||||||
|
|
||||||
|
mockWait := &pipelines.WaitGetPipelineIdle[struct{}]{
|
||||||
|
Poll: func(time.Duration, func(*pipelines.GetPipelineResponse)) (*pipelines.GetPipelineResponse, error) {
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pipelineApi := m.GetMockPipelinesAPI()
|
||||||
|
pipelineApi.EXPECT().Stop(context.Background(), pipelines.StopRequest{
|
||||||
|
PipelineId: "123",
|
||||||
|
}).Return(mockWait, nil)
|
||||||
|
|
||||||
|
err := runner.Cancel(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
|
@ -26,6 +26,9 @@ type Runner interface {
|
||||||
|
|
||||||
// Run the underlying worklow.
|
// Run the underlying worklow.
|
||||||
Run(ctx context.Context, opts *Options) (output.RunOutput, error)
|
Run(ctx context.Context, opts *Options) (output.RunOutput, error)
|
||||||
|
|
||||||
|
// Cancel the underlying workflow.
|
||||||
|
Cancel(ctx context.Context) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find locates a runner matching the specified argument.
|
// Find locates a runner matching the specified argument.
|
||||||
|
|
|
@ -27,7 +27,9 @@ func newRunCommand() *cobra.Command {
|
||||||
runOptions.Define(cmd)
|
runOptions.Define(cmd)
|
||||||
|
|
||||||
var noWait bool
|
var noWait bool
|
||||||
|
var restart bool
|
||||||
cmd.Flags().BoolVar(&noWait, "no-wait", false, "Don't wait for the run to complete.")
|
cmd.Flags().BoolVar(&noWait, "no-wait", false, "Don't wait for the run to complete.")
|
||||||
|
cmd.Flags().BoolVar(&restart, "restart", false, "Restart the run if it is already running.")
|
||||||
|
|
||||||
cmd.RunE = func(cmd *cobra.Command, args []string) error {
|
cmd.RunE = func(cmd *cobra.Command, args []string) error {
|
||||||
ctx := cmd.Context()
|
ctx := cmd.Context()
|
||||||
|
@ -68,6 +70,15 @@ func newRunCommand() *cobra.Command {
|
||||||
}
|
}
|
||||||
|
|
||||||
runOptions.NoWait = noWait
|
runOptions.NoWait = noWait
|
||||||
|
if restart {
|
||||||
|
s := cmdio.Spinner(ctx)
|
||||||
|
s <- "Cancelling all runs"
|
||||||
|
err := runner.Cancel(ctx)
|
||||||
|
close(s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
output, err := runner.Run(ctx, &runOptions)
|
output, err := runner.Run(ctx, &runOptions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
Loading…
Reference in New Issue