From ee30277119a7770e3c867e3d4fcda1ff03354457 Mon Sep 17 00:00:00 2001 From: Pieter Noordhuis Date: Thu, 21 Sep 2023 21:21:20 +0200 Subject: [PATCH] Enable target overrides for pipeline clusters (#792) ## Changes This is a follow-up to #658 and #779 for jobs. This change applies label normalization the same way the backend does. ## Tests Unit and config loading tests. --- bundle/config/resources.go | 21 +++-- bundle/config/resources/pipeline.go | 50 ++++++++++++ bundle/config/resources/pipeline_test.go | 76 +++++++++++++++++++ bundle/config/root.go | 7 +- .../override_pipeline_cluster/databricks.yml | 33 ++++++++ .../tests/override_pipeline_cluster_test.go | 29 +++++++ 6 files changed, 199 insertions(+), 17 deletions(-) create mode 100644 bundle/config/resources/pipeline_test.go create mode 100644 bundle/tests/override_pipeline_cluster/databricks.yml create mode 100644 bundle/tests/override_pipeline_cluster_test.go diff --git a/bundle/config/resources.go b/bundle/config/resources.go index 48621e443..ad1d6e9a3 100644 --- a/bundle/config/resources.go +++ b/bundle/config/resources.go @@ -131,24 +131,23 @@ func (r *Resources) SetConfigFilePath(path string) { } } -// MergeJobClusters iterates over all jobs and merges their job clusters. -// This is called after applying the target overrides. -func (r *Resources) MergeJobClusters() error { +// Merge iterates over all resources and merges chunks of the +// resource configuration that can be merged. For example, for +// jobs, this merges job cluster definitions and tasks that +// use the same `job_cluster_key`, or `task_key`, respectively. +func (r *Resources) Merge() error { for _, job := range r.Jobs { if err := job.MergeJobClusters(); err != nil { return err } - } - return nil -} - -// MergeTasks iterates over all jobs and merges their tasks. -// This is called after applying the target overrides. -func (r *Resources) MergeTasks() error { - for _, job := range r.Jobs { if err := job.MergeTasks(); err != nil { return err } } + for _, pipeline := range r.Pipelines { + if err := pipeline.MergeClusters(); err != nil { + return err + } + } return nil } diff --git a/bundle/config/resources/pipeline.go b/bundle/config/resources/pipeline.go index d3a51c575..94c0f2b02 100644 --- a/bundle/config/resources/pipeline.go +++ b/bundle/config/resources/pipeline.go @@ -1,8 +1,11 @@ package resources import ( + "strings" + "github.com/databricks/cli/bundle/config/paths" "github.com/databricks/databricks-sdk-go/service/pipelines" + "github.com/imdario/mergo" ) type Pipeline struct { @@ -13,3 +16,50 @@ type Pipeline struct { *pipelines.PipelineSpec } + +// MergeClusters merges cluster definitions with same label. +// The clusters field is a slice, and as such, overrides are appended to it. +// We can identify a cluster by its label, however, so we can use this label +// to figure out which definitions are actually overrides and merge them. +// +// Note: the cluster label is optional and defaults to 'default'. +// We therefore ALSO merge all clusters without a label. +func (p *Pipeline) MergeClusters() error { + clusters := make(map[string]*pipelines.PipelineCluster) + output := make([]pipelines.PipelineCluster, 0, len(p.Clusters)) + + // Normalize cluster labels. + // If empty, this defaults to "default". + // To make matching case insensitive, labels are lowercased. + for i := range p.Clusters { + label := p.Clusters[i].Label + if label == "" { + label = "default" + } + p.Clusters[i].Label = strings.ToLower(label) + } + + // Target overrides are always appended, so we can iterate in natural order to + // first find the base definition, and merge instances we encounter later. + for i := range p.Clusters { + label := p.Clusters[i].Label + + // Register pipeline cluster with label if not yet seen before. + ref, ok := clusters[label] + if !ok { + output = append(output, p.Clusters[i]) + clusters[label] = &output[len(output)-1] + continue + } + + // Merge this instance into the reference. + err := mergo.Merge(ref, &p.Clusters[i], mergo.WithOverride, mergo.WithAppendSlice) + if err != nil { + return err + } + } + + // Overwrite resulting slice. + p.Clusters = output + return nil +} diff --git a/bundle/config/resources/pipeline_test.go b/bundle/config/resources/pipeline_test.go new file mode 100644 index 000000000..316e3d145 --- /dev/null +++ b/bundle/config/resources/pipeline_test.go @@ -0,0 +1,76 @@ +package resources + +import ( + "strings" + "testing" + + "github.com/databricks/databricks-sdk-go/service/pipelines" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPipelineMergeClusters(t *testing.T) { + p := &Pipeline{ + PipelineSpec: &pipelines.PipelineSpec{ + Clusters: []pipelines.PipelineCluster{ + { + NodeTypeId: "i3.xlarge", + NumWorkers: 2, + PolicyId: "1234", + }, + { + Label: "maintenance", + NodeTypeId: "i3.2xlarge", + }, + { + NodeTypeId: "i3.2xlarge", + NumWorkers: 4, + }, + }, + }, + } + + err := p.MergeClusters() + require.NoError(t, err) + + assert.Len(t, p.Clusters, 2) + assert.Equal(t, "default", p.Clusters[0].Label) + assert.Equal(t, "maintenance", p.Clusters[1].Label) + + // The default cluster was merged with a subsequent one. + pc0 := p.Clusters[0] + assert.Equal(t, "i3.2xlarge", pc0.NodeTypeId) + assert.Equal(t, 4, pc0.NumWorkers) + assert.Equal(t, "1234", pc0.PolicyId) + + // The maintenance cluster was left untouched. + pc1 := p.Clusters[1] + assert.Equal(t, "i3.2xlarge", pc1.NodeTypeId) +} + +func TestPipelineMergeClustersCaseInsensitive(t *testing.T) { + p := &Pipeline{ + PipelineSpec: &pipelines.PipelineSpec{ + Clusters: []pipelines.PipelineCluster{ + { + Label: "default", + NumWorkers: 2, + }, + { + Label: "DEFAULT", + NumWorkers: 4, + }, + }, + }, + } + + err := p.MergeClusters() + require.NoError(t, err) + + assert.Len(t, p.Clusters, 1) + + // The default cluster was merged with a subsequent one. + pc0 := p.Clusters[0] + assert.Equal(t, "default", strings.ToLower(pc0.Label)) + assert.Equal(t, 4, pc0.NumWorkers) +} diff --git a/bundle/config/root.go b/bundle/config/root.go index 32883c746..3c79fb0bc 100644 --- a/bundle/config/root.go +++ b/bundle/config/root.go @@ -238,12 +238,7 @@ func (r *Root) MergeTargetOverrides(target *Target) error { return err } - err = r.Resources.MergeJobClusters() - if err != nil { - return err - } - - err = r.Resources.MergeTasks() + err = r.Resources.Merge() if err != nil { return err } diff --git a/bundle/tests/override_pipeline_cluster/databricks.yml b/bundle/tests/override_pipeline_cluster/databricks.yml new file mode 100644 index 000000000..8930f30e8 --- /dev/null +++ b/bundle/tests/override_pipeline_cluster/databricks.yml @@ -0,0 +1,33 @@ +bundle: + name: override_pipeline_cluster + +workspace: + host: https://acme.cloud.databricks.com/ + +resources: + pipelines: + foo: + name: job + clusters: + - label: default + spark_conf: + foo: bar + +targets: + development: + resources: + pipelines: + foo: + clusters: + - label: default + node_type_id: i3.xlarge + num_workers: 1 + + staging: + resources: + pipelines: + foo: + clusters: + - label: default + node_type_id: i3.2xlarge + num_workers: 4 diff --git a/bundle/tests/override_pipeline_cluster_test.go b/bundle/tests/override_pipeline_cluster_test.go new file mode 100644 index 000000000..591fe423d --- /dev/null +++ b/bundle/tests/override_pipeline_cluster_test.go @@ -0,0 +1,29 @@ +package config_tests + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestOverridePipelineClusterDev(t *testing.T) { + b := loadTarget(t, "./override_pipeline_cluster", "development") + assert.Equal(t, "job", b.Config.Resources.Pipelines["foo"].Name) + assert.Len(t, b.Config.Resources.Pipelines["foo"].Clusters, 1) + + c := b.Config.Resources.Pipelines["foo"].Clusters[0] + assert.Equal(t, map[string]string{"foo": "bar"}, c.SparkConf) + assert.Equal(t, "i3.xlarge", c.NodeTypeId) + assert.Equal(t, 1, c.NumWorkers) +} + +func TestOverridePipelineClusterStaging(t *testing.T) { + b := loadTarget(t, "./override_pipeline_cluster", "staging") + assert.Equal(t, "job", b.Config.Resources.Pipelines["foo"].Name) + assert.Len(t, b.Config.Resources.Pipelines["foo"].Clusters, 1) + + c := b.Config.Resources.Pipelines["foo"].Clusters[0] + assert.Equal(t, map[string]string{"foo": "bar"}, c.SparkConf) + assert.Equal(t, "i3.2xlarge", c.NodeTypeId) + assert.Equal(t, 4, c.NumWorkers) +}