Enable target overrides for pipeline clusters (#792)

## Changes

This is a follow-up to #658 and #779, which implemented the same for jobs.

This change applies label normalization the same way the backend does.

## Tests

Unit and config loading tests.
This commit is contained in:
Pieter Noordhuis 2023-09-21 21:21:20 +02:00 committed by GitHub
parent c65e59751b
commit ee30277119
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 199 additions and 17 deletions

View File

@ -131,24 +131,23 @@ func (r *Resources) SetConfigFilePath(path string) {
}
}
// MergeJobClusters iterates over all jobs and merges their job clusters.
// This is called after applying the target overrides.
func (r *Resources) MergeJobClusters() error {
// Merge walks all resources and combines the parts of their
// configuration that are mergeable. For jobs, this folds together
// job cluster definitions and tasks that share the same
// `job_cluster_key` or `task_key`, respectively.
func (r *Resources) Merge() error {
	for _, job := range r.Jobs {
		err := job.MergeJobClusters()
		if err != nil {
			return err
		}
	}
	return nil
}
// MergeTasks iterates over all jobs and merges their tasks, and over
// all pipelines and merges their cluster definitions.
// This is called after applying the target overrides.
func (r *Resources) MergeTasks() error {
for _, job := range r.Jobs {
if err := job.MergeTasks(); err != nil {
return err
}
}
// NOTE(review): this loop merges pipeline clusters even though the
// function name only mentions tasks — consider renaming or splitting.
for _, pipeline := range r.Pipelines {
if err := pipeline.MergeClusters(); err != nil {
return err
}
}
return nil
}

View File

@ -1,8 +1,11 @@
package resources
import (
"strings"
"github.com/databricks/cli/bundle/config/paths"
"github.com/databricks/databricks-sdk-go/service/pipelines"
"github.com/imdario/mergo"
)
type Pipeline struct {
@ -13,3 +16,50 @@ type Pipeline struct {
*pipelines.PipelineSpec
}
// MergeClusters consolidates cluster definitions that share a label.
// Because the clusters field is a slice, target overrides are appended
// rather than replaced; the label identifies which entries are overrides
// of an earlier base definition so they can be folded together.
//
// Note: the cluster label is optional and defaults to 'default', so
// unlabeled clusters are also merged with each other.
func (p *Pipeline) MergeClusters() error {
	// First pass: normalize every label in place. An empty label becomes
	// "default", and labels are lowercased so matching is case insensitive.
	for i := range p.Clusters {
		if p.Clusters[i].Label == "" {
			p.Clusters[i].Label = "default"
		}
		p.Clusters[i].Label = strings.ToLower(p.Clusters[i].Label)
	}

	// Second pass: fold overrides into the first occurrence of each label.
	// Target overrides are always appended, so a forward scan encounters
	// the base definition before any of its overrides. The output slice is
	// preallocated at full capacity so the pointers stored in seen remain
	// valid across appends (no reallocation can occur).
	seen := make(map[string]*pipelines.PipelineCluster)
	merged := make([]pipelines.PipelineCluster, 0, len(p.Clusters))
	for i := range p.Clusters {
		label := p.Clusters[i].Label
		base, ok := seen[label]
		if ok {
			// Merge this override into the base definition for its label.
			if err := mergo.Merge(base, &p.Clusters[i], mergo.WithOverride, mergo.WithAppendSlice); err != nil {
				return err
			}
			continue
		}
		// First time this label is seen: it becomes the base definition.
		merged = append(merged, p.Clusters[i])
		seen[label] = &merged[len(merged)-1]
	}

	p.Clusters = merged
	return nil
}

View File

@ -0,0 +1,76 @@
package resources
import (
"strings"
"testing"
"github.com/databricks/databricks-sdk-go/service/pipelines"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPipelineMergeClusters(t *testing.T) {
	p := &Pipeline{
		PipelineSpec: &pipelines.PipelineSpec{
			Clusters: []pipelines.PipelineCluster{
				{
					NodeTypeId: "i3.xlarge",
					NumWorkers: 2,
					PolicyId:   "1234",
				},
				{
					Label:      "maintenance",
					NodeTypeId: "i3.2xlarge",
				},
				{
					NodeTypeId: "i3.2xlarge",
					NumWorkers: 4,
				},
			},
		},
	}

	require.NoError(t, p.MergeClusters())

	assert.Len(t, p.Clusters, 2)
	assert.Equal(t, "default", p.Clusters[0].Label)
	assert.Equal(t, "maintenance", p.Clusters[1].Label)

	// The two unlabeled clusters collapsed into a single "default" entry,
	// with the later definition's fields overriding the earlier one's.
	merged := p.Clusters[0]
	assert.Equal(t, "i3.2xlarge", merged.NodeTypeId)
	assert.Equal(t, 4, merged.NumWorkers)
	assert.Equal(t, "1234", merged.PolicyId)

	// The labeled maintenance cluster had nothing to merge with.
	assert.Equal(t, "i3.2xlarge", p.Clusters[1].NodeTypeId)
}
func TestPipelineMergeClustersCaseInsensitive(t *testing.T) {
	p := &Pipeline{
		PipelineSpec: &pipelines.PipelineSpec{
			Clusters: []pipelines.PipelineCluster{
				{
					Label:      "default",
					NumWorkers: 2,
				},
				{
					Label:      "DEFAULT",
					NumWorkers: 4,
				},
			},
		},
	}

	require.NoError(t, p.MergeClusters())

	assert.Len(t, p.Clusters, 1)

	// Labels differing only in case refer to the same cluster; the later
	// definition's worker count wins.
	// NOTE(review): lowercasing the actual value before comparing makes
	// this assertion insensitive to the normalized casing; asserting
	// pc.Label == "default" directly would pin the behavior tighter.
	merged := p.Clusters[0]
	assert.Equal(t, "default", strings.ToLower(merged.Label))
	assert.Equal(t, 4, merged.NumWorkers)
}

View File

@ -238,12 +238,7 @@ func (r *Root) MergeTargetOverrides(target *Target) error {
return err
}
err = r.Resources.MergeJobClusters()
if err != nil {
return err
}
err = r.Resources.MergeTasks()
err = r.Resources.Merge()
if err != nil {
return err
}

View File

# Fixture for pipeline cluster target-override tests: a single pipeline
# whose "default" cluster is overridden per target. The base definition
# contributes spark_conf; each target supplies node type and worker count.
bundle:
  name: override_pipeline_cluster

workspace:
  host: https://acme.cloud.databricks.com/

resources:
  pipelines:
    foo:
      name: job
      clusters:
        - label: default
          spark_conf:
            foo: bar

targets:
  development:
    resources:
      pipelines:
        foo:
          clusters:
            - label: default
              node_type_id: i3.xlarge
              num_workers: 1
  staging:
    resources:
      pipelines:
        foo:
          clusters:
            - label: default
              node_type_id: i3.2xlarge
              num_workers: 4
View File

@ -0,0 +1,29 @@
package config_tests
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestOverridePipelineClusterDev(t *testing.T) {
	b := loadTarget(t, "./override_pipeline_cluster", "development")

	pipeline := b.Config.Resources.Pipelines["foo"]
	assert.Equal(t, "job", pipeline.Name)
	assert.Len(t, pipeline.Clusters, 1)

	// The base spark_conf is retained; the development target contributes
	// the node type and worker count.
	c := pipeline.Clusters[0]
	assert.Equal(t, map[string]string{"foo": "bar"}, c.SparkConf)
	assert.Equal(t, "i3.xlarge", c.NodeTypeId)
	assert.Equal(t, 1, c.NumWorkers)
}
func TestOverridePipelineClusterStaging(t *testing.T) {
	b := loadTarget(t, "./override_pipeline_cluster", "staging")

	pipeline := b.Config.Resources.Pipelines["foo"]
	assert.Equal(t, "job", pipeline.Name)
	assert.Len(t, pipeline.Clusters, 1)

	// The base spark_conf is retained; the staging target contributes the
	// node type and worker count.
	c := pipeline.Clusters[0]
	assert.Equal(t, map[string]string{"foo": "bar"}, c.SparkConf)
	assert.Equal(t, "i3.2xlarge", c.NodeTypeId)
	assert.Equal(t, 4, c.NumWorkers)
}