diff --git a/bundle/config/mutator/override_compute.go b/bundle/config/mutator/override_compute.go index 3ceea4dec..8ce28cbcc 100644 --- a/bundle/config/mutator/override_compute.go +++ b/bundle/config/mutator/override_compute.go @@ -3,6 +3,7 @@ package mutator import ( "context" "fmt" + "os" "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config" @@ -32,11 +33,17 @@ func overrideJobCompute(j *resources.Job, compute string) { } func (m *overrideCompute) Apply(ctx context.Context, b *bundle.Bundle) error { - if b.Config.Bundle.Compute == "" { + if b.Config.Bundle.Mode != config.Development { + if b.Config.Bundle.Compute != "" { + return fmt.Errorf("cannot override compute for an environment that does not use 'mode: development'") + } return nil } - if b.Config.Bundle.Mode != config.Development { - return fmt.Errorf("cannot override compute for an environment that does not use 'mode: debug'") + if os.Getenv("DATABRICKS_CLUSTER_ID") != "" { + b.Config.Bundle.Compute = os.Getenv("DATABRICKS_CLUSTER_ID") + } + if b.Config.Bundle.Compute == "" { + return nil } r := b.Config.Resources diff --git a/bundle/config/mutator/override_compute_test.go b/bundle/config/mutator/override_compute_test.go index 22fb5a66d..a9da6856f 100644 --- a/bundle/config/mutator/override_compute_test.go +++ b/bundle/config/mutator/override_compute_test.go @@ -2,6 +2,7 @@ package mutator_test import ( "context" + "os" "testing" "github.com/databricks/cli/bundle" @@ -14,7 +15,8 @@ import ( "github.com/stretchr/testify/require" ) -func TestOverrideCompute(t *testing.T) { +func TestOverrideDevelopment(t *testing.T) { + os.Setenv("DATABRICKS_CLUSTER_ID", "") bundle := &bundle.Bundle{ Config: config.Root{ Bundle: config.Bundle{ @@ -46,3 +48,87 @@ func TestOverrideCompute(t *testing.T) { assert.Equal(t, "newClusterID", bundle.Config.Resources.Jobs["job1"].Tasks[0].ExistingClusterId) assert.Equal(t, "newClusterID", bundle.Config.Resources.Jobs["job1"].Tasks[1].ExistingClusterId) } + +func TestOverrideDevelopmentEnv(t *testing.T) { + os.Setenv("DATABRICKS_CLUSTER_ID", "newClusterId") + bundle := &bundle.Bundle{ + Config: config.Root{ + Resources: config.Resources{ + Jobs: map[string]*resources.Job{ + "job1": {JobSettings: &jobs.JobSettings{ + Name: "job1", + Tasks: []jobs.Task{ + { + NewCluster: &compute.ClusterSpec{}, + }, + { + ExistingClusterId: "cluster2", + }, + }, + }}, + }, + }, + }, + } + + m := mutator.OverrideCompute() + err := m.Apply(context.Background(), bundle) + require.NoError(t, err) + assert.Equal(t, "cluster2", bundle.Config.Resources.Jobs["job1"].Tasks[1].ExistingClusterId) +} + +func TestOverrideProduction(t *testing.T) { + bundle := &bundle.Bundle{ + Config: config.Root{ + Bundle: config.Bundle{ + Compute: "newClusterID", + }, + Resources: config.Resources{ + Jobs: map[string]*resources.Job{ + "job1": {JobSettings: &jobs.JobSettings{ + Name: "job1", + Tasks: []jobs.Task{ + { + NewCluster: &compute.ClusterSpec{}, + }, + { + ExistingClusterId: "cluster2", + }, + }, + }}, + }, + }, + }, + } + + m := mutator.OverrideCompute() + err := m.Apply(context.Background(), bundle) + require.Error(t, err) +} + +func TestOverrideProductionEnv(t *testing.T) { + os.Setenv("DATABRICKS_CLUSTER_ID", "newClusterId") + bundle := &bundle.Bundle{ + Config: config.Root{ + Resources: config.Resources{ + Jobs: map[string]*resources.Job{ + "job1": {JobSettings: &jobs.JobSettings{ + Name: "job1", + Tasks: []jobs.Task{ + { + NewCluster: &compute.ClusterSpec{}, + }, + { + ExistingClusterId: "cluster2", + }, + }, + }}, + }, + }, + }, + } + + m := mutator.OverrideCompute() + err := m.Apply(context.Background(), bundle) + require.NoError(t, err) +} diff --git a/cmd/bundle/deploy.go b/cmd/bundle/deploy.go index 0cf961189..05df84c59 100644 --- a/cmd/bundle/deploy.go +++ b/cmd/bundle/deploy.go @@ -1,8 +1,6 @@ package bundle import ( - "os" - "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/phases" "github.com/spf13/cobra" @@ -21,10 +19,6 @@ var deployCmd = &cobra.Command{ } func deploy(cmd *cobra.Command, b *bundle.Bundle) error { - if computeID == "" { - computeID = os.Getenv("DATABRICKS_CLUSTER_ID") - } - // If `--force` is specified, force acquisition of the deployment lock. b.Config.Bundle.Lock.Force = forceDeploy b.Config.Bundle.Compute = computeID