Fix race condition when restarting continuous jobs (#1849)

## Changes
We don't need to cancel existing runs when the job is continuous and
unpaused. The `/jobs/run-now` command will cancel the existing run and
trigger a new one automatically.

Cancelling the job manually can cause a race condition where both the
manual trigger from the CLI and the continuous trigger from the job
configuration happens at the same time. This PR prevents that from
happening.

## Tests
Unit tests and manually
This commit is contained in:
shreyas-goenka 2024-10-22 20:29:17 +05:30 committed by GitHub
parent 68d69d6e0b
commit 3bab21e72e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 247 additions and 8 deletions

View File

@ -317,6 +317,29 @@ func (r *jobRunner) Cancel(ctx context.Context) error {
return errGroup.Wait() return errGroup.Wait()
} }
func (r *jobRunner) Restart(ctx context.Context, opts *Options) (output.RunOutput, error) {
// We don't need to cancel existing runs if the job is continuous and unpaused.
// the /jobs/run-now API will automatically cancel any existing runs before starting a new one.
//
// /jobs/run-now will not cancel existing runs if the job is continuous and paused.
// New job runs will be queued instead and will wait for existing runs to finish.
// In this case, we need to cancel the existing runs before starting a new one.
continuous := r.job.JobSettings.Continuous
if continuous != nil && continuous.PauseStatus == jobs.PauseStatusUnpaused {
return r.Run(ctx, opts)
}
s := cmdio.Spinner(ctx)
s <- "Cancelling all active job runs"
err := r.Cancel(ctx)
close(s)
if err != nil {
return nil, err
}
return r.Run(ctx, opts)
}
func (r *jobRunner) ParseArgs(args []string, opts *Options) error { func (r *jobRunner) ParseArgs(args []string, opts *Options) error {
return r.posArgsHandler().ParseArgs(args, opts) return r.posArgsHandler().ParseArgs(args, opts)
} }

View File

@ -1,6 +1,7 @@
package run package run
import ( import (
"bytes"
"context" "context"
"testing" "testing"
"time" "time"
@ -8,6 +9,8 @@ import (
"github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config" "github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/resources" "github.com/databricks/cli/bundle/config/resources"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/flags"
"github.com/databricks/databricks-sdk-go/experimental/mocks" "github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/jobs" "github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
@ -126,3 +129,132 @@ func TestJobRunnerCancelWithNoActiveRuns(t *testing.T) {
err := runner.Cancel(context.Background()) err := runner.Cancel(context.Background())
require.NoError(t, err) require.NoError(t, err)
} }
func TestJobRunnerRestart(t *testing.T) {
for _, jobSettings := range []*jobs.JobSettings{
{},
{
Continuous: &jobs.Continuous{
PauseStatus: jobs.PauseStatusPaused,
},
},
} {
job := &resources.Job{
ID: "123",
JobSettings: jobSettings,
}
b := &bundle.Bundle{
Config: config.Root{
Resources: config.Resources{
Jobs: map[string]*resources.Job{
"test_job": job,
},
},
},
}
runner := jobRunner{key: "test", bundle: b, job: job}
m := mocks.NewMockWorkspaceClient(t)
b.SetWorkpaceClient(m.WorkspaceClient)
ctx := context.Background()
ctx = cmdio.InContext(ctx, cmdio.NewIO(flags.OutputText, &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}, "", ""))
ctx = cmdio.NewContext(ctx, cmdio.NewLogger(flags.ModeAppend))
jobApi := m.GetMockJobsAPI()
jobApi.EXPECT().ListRunsAll(mock.Anything, jobs.ListRunsRequest{
ActiveOnly: true,
JobId: 123,
}).Return([]jobs.BaseRun{
{RunId: 1},
{RunId: 2},
}, nil)
// Mock the runner cancelling existing job runs.
mockWait := &jobs.WaitGetRunJobTerminatedOrSkipped[struct{}]{
Poll: func(time time.Duration, f func(j *jobs.Run)) (*jobs.Run, error) {
return nil, nil
},
}
jobApi.EXPECT().CancelRun(mock.Anything, jobs.CancelRun{
RunId: 1,
}).Return(mockWait, nil)
jobApi.EXPECT().CancelRun(mock.Anything, jobs.CancelRun{
RunId: 2,
}).Return(mockWait, nil)
// Mock the runner triggering a job run
mockWaitForRun := &jobs.WaitGetRunJobTerminatedOrSkipped[jobs.RunNowResponse]{
Poll: func(d time.Duration, f func(*jobs.Run)) (*jobs.Run, error) {
return &jobs.Run{
State: &jobs.RunState{
ResultState: jobs.RunResultStateSuccess,
},
}, nil
},
}
jobApi.EXPECT().RunNow(mock.Anything, jobs.RunNow{
JobId: 123,
}).Return(mockWaitForRun, nil)
// Mock the runner getting the job output
jobApi.EXPECT().GetRun(mock.Anything, jobs.GetRunRequest{}).Return(&jobs.Run{}, nil)
_, err := runner.Restart(ctx, &Options{})
require.NoError(t, err)
}
}
func TestJobRunnerRestartForContinuousUnpausedJobs(t *testing.T) {
job := &resources.Job{
ID: "123",
JobSettings: &jobs.JobSettings{
Continuous: &jobs.Continuous{
PauseStatus: jobs.PauseStatusUnpaused,
},
},
}
b := &bundle.Bundle{
Config: config.Root{
Resources: config.Resources{
Jobs: map[string]*resources.Job{
"test_job": job,
},
},
},
}
runner := jobRunner{key: "test", bundle: b, job: job}
m := mocks.NewMockWorkspaceClient(t)
b.SetWorkpaceClient(m.WorkspaceClient)
ctx := context.Background()
ctx = cmdio.InContext(ctx, cmdio.NewIO(flags.OutputText, &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}, "", "..."))
ctx = cmdio.NewContext(ctx, cmdio.NewLogger(flags.ModeAppend))
jobApi := m.GetMockJobsAPI()
// The runner should not try and cancel existing job runs for unpaused continuous jobs.
jobApi.AssertNotCalled(t, "ListRunsAll")
jobApi.AssertNotCalled(t, "CancelRun")
// Mock the runner triggering a job run
mockWaitForRun := &jobs.WaitGetRunJobTerminatedOrSkipped[jobs.RunNowResponse]{
Poll: func(d time.Duration, f func(*jobs.Run)) (*jobs.Run, error) {
return &jobs.Run{
State: &jobs.RunState{
ResultState: jobs.RunResultStateSuccess,
},
}, nil
},
}
jobApi.EXPECT().RunNow(mock.Anything, jobs.RunNow{
JobId: 123,
}).Return(mockWaitForRun, nil)
// Mock the runner getting the job output
jobApi.EXPECT().GetRun(mock.Anything, jobs.GetRunRequest{}).Return(&jobs.Run{}, nil)
_, err := runner.Restart(ctx, &Options{})
require.NoError(t, err)
}

View File

@ -183,6 +183,18 @@ func (r *pipelineRunner) Cancel(ctx context.Context) error {
return err return err
} }
func (r *pipelineRunner) Restart(ctx context.Context, opts *Options) (output.RunOutput, error) {
s := cmdio.Spinner(ctx)
s <- "Cancelling the active pipeline update"
err := r.Cancel(ctx)
close(s)
if err != nil {
return nil, err
}
return r.Run(ctx, opts)
}
func (r *pipelineRunner) ParseArgs(args []string, opts *Options) error { func (r *pipelineRunner) ParseArgs(args []string, opts *Options) error {
if len(args) == 0 { if len(args) == 0 {
return nil return nil

View File

@ -1,6 +1,7 @@
package run package run
import ( import (
"bytes"
"context" "context"
"testing" "testing"
"time" "time"
@ -8,8 +9,12 @@ import (
"github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config" "github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/resources" "github.com/databricks/cli/bundle/config/resources"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/flags"
sdk_config "github.com/databricks/databricks-sdk-go/config"
"github.com/databricks/databricks-sdk-go/experimental/mocks" "github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/pipelines" "github.com/databricks/databricks-sdk-go/service/pipelines"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -47,3 +52,68 @@ func TestPipelineRunnerCancel(t *testing.T) {
err := runner.Cancel(context.Background()) err := runner.Cancel(context.Background())
require.NoError(t, err) require.NoError(t, err)
} }
func TestPipelineRunnerRestart(t *testing.T) {
pipeline := &resources.Pipeline{
ID: "123",
}
b := &bundle.Bundle{
Config: config.Root{
Resources: config.Resources{
Pipelines: map[string]*resources.Pipeline{
"test_pipeline": pipeline,
},
},
},
}
runner := pipelineRunner{key: "test", bundle: b, pipeline: pipeline}
m := mocks.NewMockWorkspaceClient(t)
m.WorkspaceClient.Config = &sdk_config.Config{
Host: "https://test.com",
}
b.SetWorkpaceClient(m.WorkspaceClient)
ctx := context.Background()
ctx = cmdio.InContext(ctx, cmdio.NewIO(flags.OutputText, &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}, "", "..."))
ctx = cmdio.NewContext(ctx, cmdio.NewLogger(flags.ModeAppend))
mockWait := &pipelines.WaitGetPipelineIdle[struct{}]{
Poll: func(time.Duration, func(*pipelines.GetPipelineResponse)) (*pipelines.GetPipelineResponse, error) {
return nil, nil
},
}
pipelineApi := m.GetMockPipelinesAPI()
pipelineApi.EXPECT().Stop(mock.Anything, pipelines.StopRequest{
PipelineId: "123",
}).Return(mockWait, nil)
pipelineApi.EXPECT().GetByPipelineId(mock.Anything, "123").Return(&pipelines.GetPipelineResponse{}, nil)
// Mock runner starting a new update
pipelineApi.EXPECT().StartUpdate(mock.Anything, pipelines.StartUpdate{
PipelineId: "123",
}).Return(&pipelines.StartUpdateResponse{
UpdateId: "456",
}, nil)
// Mock runner polling for events
pipelineApi.EXPECT().ListPipelineEventsAll(mock.Anything, pipelines.ListPipelineEventsRequest{
Filter: `update_id = '456'`,
MaxResults: 100,
PipelineId: "123",
}).Return([]pipelines.PipelineEvent{}, nil)
// Mock runner polling for update status
pipelineApi.EXPECT().GetUpdateByPipelineIdAndUpdateId(mock.Anything, "123", "456").
Return(&pipelines.GetUpdateResponse{
Update: &pipelines.UpdateInfo{
State: pipelines.UpdateInfoStateCompleted,
},
}, nil)
_, err := runner.Restart(ctx, &Options{})
require.NoError(t, err)
}

View File

@ -27,6 +27,10 @@ type Runner interface {
// Run the underlying worklow. // Run the underlying worklow.
Run(ctx context.Context, opts *Options) (output.RunOutput, error) Run(ctx context.Context, opts *Options) (output.RunOutput, error)
// Restart the underlying workflow by cancelling any existing runs before
// starting a new one.
Restart(ctx context.Context, opts *Options) (output.RunOutput, error)
// Cancel the underlying workflow. // Cancel the underlying workflow.
Cancel(ctx context.Context) error Cancel(ctx context.Context) error

View File

@ -8,6 +8,7 @@ import (
"github.com/databricks/cli/bundle/deploy/terraform" "github.com/databricks/cli/bundle/deploy/terraform"
"github.com/databricks/cli/bundle/phases" "github.com/databricks/cli/bundle/phases"
"github.com/databricks/cli/bundle/run" "github.com/databricks/cli/bundle/run"
"github.com/databricks/cli/bundle/run/output"
"github.com/databricks/cli/cmd/bundle/utils" "github.com/databricks/cli/cmd/bundle/utils"
"github.com/databricks/cli/cmd/root" "github.com/databricks/cli/cmd/root"
"github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/cmdio"
@ -100,19 +101,16 @@ task or a Python wheel task, the second example applies.
} }
runOptions.NoWait = noWait runOptions.NoWait = noWait
var output output.RunOutput
if restart { if restart {
s := cmdio.Spinner(ctx) output, err = runner.Restart(ctx, &runOptions)
s <- "Cancelling all runs" } else {
err := runner.Cancel(ctx) output, err = runner.Run(ctx, &runOptions)
close(s) }
if err != nil {
return err
}
}
output, err := runner.Run(ctx, &runOptions)
if err != nil { if err != nil {
return err return err
} }
if output != nil { if output != nil {
switch root.OutputType(cmd) { switch root.OutputType(cmd) {
case flags.OutputText: case flags.OutputText: