mirror of https://github.com/databricks/cli.git
Added `--restart` flag for `bundle run` command (#1191)
## Changes Added `--restart` flag for `bundle run` command When running with this flag, `bundle run` will cancel all existing runs before starting a new one ## Tests Manually
This commit is contained in:
parent
cac112c5bc
commit
bc30c9ed4a
|
@ -15,6 +15,7 @@ import (
|
|||
"github.com/databricks/cli/libs/log"
|
||||
"github.com/databricks/databricks-sdk-go/service/jobs"
|
||||
"github.com/fatih/color"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// Default timeout for waiting for a job run to complete.
|
||||
|
@ -275,3 +276,42 @@ func (r *jobRunner) convertPythonParams(opts *Options) error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *jobRunner) Cancel(ctx context.Context) error {
|
||||
w := r.bundle.WorkspaceClient()
|
||||
jobID, err := strconv.ParseInt(r.job.ID, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("job ID is not an integer: %s", r.job.ID)
|
||||
}
|
||||
|
||||
runs, err := w.Jobs.ListRunsAll(ctx, jobs.ListRunsRequest{
|
||||
ActiveOnly: true,
|
||||
JobId: jobID,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(runs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
errGroup, errCtx := errgroup.WithContext(ctx)
|
||||
for _, run := range runs {
|
||||
runId := run.RunId
|
||||
errGroup.Go(func() error {
|
||||
wait, err := w.Jobs.CancelRun(errCtx, jobs.CancelRun{
|
||||
RunId: runId,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Waits for the Terminated or Skipped state
|
||||
_, err = wait.GetWithTimeout(jobRunTimeout)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
return errGroup.Wait()
|
||||
}
|
||||
|
|
|
@ -1,12 +1,16 @@
|
|||
package run
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/databricks/cli/bundle"
|
||||
"github.com/databricks/cli/bundle/config"
|
||||
"github.com/databricks/cli/bundle/config/resources"
|
||||
"github.com/databricks/databricks-sdk-go/experimental/mocks"
|
||||
"github.com/databricks/databricks-sdk-go/service/jobs"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
|
@ -47,3 +51,78 @@ func TestConvertPythonParams(t *testing.T) {
|
|||
require.Contains(t, opts.Job.notebookParams, "__python_params")
|
||||
require.Equal(t, opts.Job.notebookParams["__python_params"], `["param1","param2","param3"]`)
|
||||
}
|
||||
|
||||
func TestJobRunnerCancel(t *testing.T) {
|
||||
job := &resources.Job{
|
||||
ID: "123",
|
||||
}
|
||||
b := &bundle.Bundle{
|
||||
Config: config.Root{
|
||||
Resources: config.Resources{
|
||||
Jobs: map[string]*resources.Job{
|
||||
"test_job": job,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
runner := jobRunner{key: "test", bundle: b, job: job}
|
||||
|
||||
m := mocks.NewMockWorkspaceClient(t)
|
||||
b.SetWorkpaceClient(m.WorkspaceClient)
|
||||
|
||||
jobApi := m.GetMockJobsAPI()
|
||||
jobApi.EXPECT().ListRunsAll(mock.Anything, jobs.ListRunsRequest{
|
||||
ActiveOnly: true,
|
||||
JobId: 123,
|
||||
}).Return([]jobs.BaseRun{
|
||||
{RunId: 1},
|
||||
{RunId: 2},
|
||||
}, nil)
|
||||
|
||||
mockWait := &jobs.WaitGetRunJobTerminatedOrSkipped[struct{}]{
|
||||
Poll: func(time time.Duration, f func(j *jobs.Run)) (*jobs.Run, error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
jobApi.EXPECT().CancelRun(mock.Anything, jobs.CancelRun{
|
||||
RunId: 1,
|
||||
}).Return(mockWait, nil)
|
||||
jobApi.EXPECT().CancelRun(mock.Anything, jobs.CancelRun{
|
||||
RunId: 2,
|
||||
}).Return(mockWait, nil)
|
||||
|
||||
err := runner.Cancel(context.Background())
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestJobRunnerCancelWithNoActiveRuns(t *testing.T) {
|
||||
job := &resources.Job{
|
||||
ID: "123",
|
||||
}
|
||||
b := &bundle.Bundle{
|
||||
Config: config.Root{
|
||||
Resources: config.Resources{
|
||||
Jobs: map[string]*resources.Job{
|
||||
"test_job": job,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
runner := jobRunner{key: "test", bundle: b, job: job}
|
||||
|
||||
m := mocks.NewMockWorkspaceClient(t)
|
||||
b.SetWorkpaceClient(m.WorkspaceClient)
|
||||
|
||||
jobApi := m.GetMockJobsAPI()
|
||||
jobApi.EXPECT().ListRunsAll(mock.Anything, jobs.ListRunsRequest{
|
||||
ActiveOnly: true,
|
||||
JobId: 123,
|
||||
}).Return([]jobs.BaseRun{}, nil)
|
||||
|
||||
jobApi.AssertNotCalled(t, "CancelRun")
|
||||
|
||||
err := runner.Cancel(context.Background())
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
|
|
@ -166,3 +166,18 @@ func (r *pipelineRunner) Run(ctx context.Context, opts *Options) (output.RunOutp
|
|||
time.Sleep(time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *pipelineRunner) Cancel(ctx context.Context) error {
|
||||
w := r.bundle.WorkspaceClient()
|
||||
wait, err := w.Pipelines.Stop(ctx, pipelines.StopRequest{
|
||||
PipelineId: r.pipeline.ID,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Waits for the Idle state of the pipeline
|
||||
_, err = wait.GetWithTimeout(jobRunTimeout)
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
package run
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/databricks/cli/bundle"
|
||||
"github.com/databricks/cli/bundle/config"
|
||||
"github.com/databricks/cli/bundle/config/resources"
|
||||
"github.com/databricks/databricks-sdk-go/experimental/mocks"
|
||||
"github.com/databricks/databricks-sdk-go/service/pipelines"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPipelineRunnerCancel(t *testing.T) {
|
||||
pipeline := &resources.Pipeline{
|
||||
ID: "123",
|
||||
}
|
||||
|
||||
b := &bundle.Bundle{
|
||||
Config: config.Root{
|
||||
Resources: config.Resources{
|
||||
Pipelines: map[string]*resources.Pipeline{
|
||||
"test_pipeline": pipeline,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
runner := pipelineRunner{key: "test", bundle: b, pipeline: pipeline}
|
||||
|
||||
m := mocks.NewMockWorkspaceClient(t)
|
||||
b.SetWorkpaceClient(m.WorkspaceClient)
|
||||
|
||||
mockWait := &pipelines.WaitGetPipelineIdle[struct{}]{
|
||||
Poll: func(time.Duration, func(*pipelines.GetPipelineResponse)) (*pipelines.GetPipelineResponse, error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
pipelineApi := m.GetMockPipelinesAPI()
|
||||
pipelineApi.EXPECT().Stop(context.Background(), pipelines.StopRequest{
|
||||
PipelineId: "123",
|
||||
}).Return(mockWait, nil)
|
||||
|
||||
err := runner.Cancel(context.Background())
|
||||
require.NoError(t, err)
|
||||
}
|
|
@ -26,6 +26,9 @@ type Runner interface {
|
|||
|
||||
// Run the underlying worklow.
|
||||
Run(ctx context.Context, opts *Options) (output.RunOutput, error)
|
||||
|
||||
// Cancel the underlying workflow.
|
||||
Cancel(ctx context.Context) error
|
||||
}
|
||||
|
||||
// Find locates a runner matching the specified argument.
|
||||
|
|
|
@ -27,7 +27,9 @@ func newRunCommand() *cobra.Command {
|
|||
runOptions.Define(cmd)
|
||||
|
||||
var noWait bool
|
||||
var restart bool
|
||||
cmd.Flags().BoolVar(&noWait, "no-wait", false, "Don't wait for the run to complete.")
|
||||
cmd.Flags().BoolVar(&restart, "restart", false, "Restart the run if it is already running.")
|
||||
|
||||
cmd.RunE = func(cmd *cobra.Command, args []string) error {
|
||||
ctx := cmd.Context()
|
||||
|
@ -68,6 +70,15 @@ func newRunCommand() *cobra.Command {
|
|||
}
|
||||
|
||||
runOptions.NoWait = noWait
|
||||
if restart {
|
||||
s := cmdio.Spinner(ctx)
|
||||
s <- "Cancelling all runs"
|
||||
err := runner.Cancel(ctx)
|
||||
close(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
output, err := runner.Run(ctx, &runOptions)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
Loading…
Reference in New Issue