diff --git a/bundle/config/root.go b/bundle/config/root.go index 60faba29..2bbb7869 100644 --- a/bundle/config/root.go +++ b/bundle/config/root.go @@ -338,13 +338,36 @@ func (r *Root) MergeTargetOverrides(name string) error { "resources", "sync", "permissions", - "variables", } { if root, err = mergeField(root, target, f); err != nil { return err } } + // Merge `variables`. This field must be overwritten if set, not merged. + if v := target.Get("variables"); v.Kind() != dyn.KindInvalid { + _, err = dyn.Map(v, ".", dyn.Foreach(func(p dyn.Path, variable dyn.Value) (dyn.Value, error) { + varPath := dyn.MustPathFromString("variables").Append(p...) + + vDefault := variable.Get("default") + if vDefault.Kind() != dyn.KindInvalid { + defaultPath := varPath.Append(dyn.Key("default")) + root, err = dyn.SetByPath(root, defaultPath, vDefault) + } + + vLookup := variable.Get("lookup") + if vLookup.Kind() != dyn.KindInvalid { + lookupPath := varPath.Append(dyn.Key("lookup")) + root, err = dyn.SetByPath(root, lookupPath, vLookup) + } + + return root, err + })) + if err != nil { + return err + } + } + // Merge `run_as`. This field must be overwritten if set, not merged. if v := target.Get("run_as"); v.Kind() != dyn.KindInvalid { root, err = dyn.Set(root, "run_as", v) @@ -444,6 +467,7 @@ func rewriteShorthands(v dyn.Value) (dyn.Value, error) { if typeV.MustString() == "complex" { return dyn.NewValue(map[string]dyn.Value{ + "type": typeV, "default": variable, }, variable.Location()), nil } diff --git a/bundle/config/root_test.go b/bundle/config/root_test.go index 27cc3d22..aed670d6 100644 --- a/bundle/config/root_test.go +++ b/bundle/config/root_test.go @@ -132,3 +132,56 @@ func TestInitializeComplexVariablesViaFlagIsNotAllowed(t *testing.T) { err := root.InitializeVariables([]string{"foo=123"}) assert.ErrorContains(t, err, "setting variables of complex type via --var flag is not supported: foo") } + +func TestRootMergeTargetOverridesWithVariables(t *testing.T) { + root := &Root{ + Bundle: Bundle{}, + Variables: map[string]*variable.Variable{ + "foo": { + Default: "foo", + Description: "foo var", + }, + "foo2": { + Default: "foo2", + Description: "foo2 var", + }, + "complex": { + Type: variable.VariableTypeComplex, + Description: "complex var", + Default: map[string]interface{}{ + "key": "value", + }, + }, + }, + Targets: map[string]*Target{ + "development": { + Variables: map[string]*variable.Variable{ + "foo": { + Default: "bar", + Description: "wrong", + }, + "complex": { + Type: "wrong", + Description: "wrong", + Default: map[string]interface{}{ + "key1": "value1", + }, + }, + }, + }, + }, + } + root.initializeDynamicValue() + require.NoError(t, root.MergeTargetOverrides("development")) + assert.Equal(t, "bar", root.Variables["foo"].Default) + assert.Equal(t, "foo var", root.Variables["foo"].Description) + + assert.Equal(t, "foo2", root.Variables["foo2"].Default) + assert.Equal(t, "foo2 var", root.Variables["foo2"].Description) + + assert.Equal(t, map[string]interface{}{ + "key1": "value1", + }, root.Variables["complex"].Default) + assert.Equal(t, "complex var", root.Variables["complex"].Description) + +} diff --git a/bundle/tests/complex_variables_test.go b/bundle/tests/complex_variables_test.go index ffe80e41..1badea6d 100644 --- a/bundle/tests/complex_variables_test.go +++ b/bundle/tests/complex_variables_test.go @@ -25,8 +25,10 @@ func TestComplexVariables(t *testing.T) { require.Equal(t, "13.2.x-scala2.11", b.Config.Resources.Jobs["my_job"].JobClusters[0].NewCluster.SparkVersion) require.Equal(t, "Standard_DS3_v2", b.Config.Resources.Jobs["my_job"].JobClusters[0].NewCluster.NodeTypeId) + require.Equal(t, "some-policy-id", b.Config.Resources.Jobs["my_job"].JobClusters[0].NewCluster.PolicyId) require.Equal(t, 2, b.Config.Resources.Jobs["my_job"].JobClusters[0].NewCluster.NumWorkers) require.Equal(t, "true", b.Config.Resources.Jobs["my_job"].JobClusters[0].NewCluster.SparkConf["spark.speculation"]) + require.Equal(t, "true", b.Config.Resources.Jobs["my_job"].JobClusters[0].NewCluster.SparkConf["spark.random"]) require.Equal(t, 3, len(b.Config.Resources.Jobs["my_job"].Tasks[0].Libraries)) require.Contains(t, b.Config.Resources.Jobs["my_job"].Tasks[0].Libraries, compute.Library{ @@ -59,4 +61,10 @@ func TestComplexVariablesOverride(t *testing.T) { require.Equal(t, "Standard_DS3_v3", b.Config.Resources.Jobs["my_job"].JobClusters[0].NewCluster.NodeTypeId) require.Equal(t, 4, b.Config.Resources.Jobs["my_job"].JobClusters[0].NewCluster.NumWorkers) require.Equal(t, "false", b.Config.Resources.Jobs["my_job"].JobClusters[0].NewCluster.SparkConf["spark.speculation"]) + + // Making sure the variable is overriden and not merged / extended + // These properties are set in the default target but not set in override target + // So they should be empty + require.Equal(t, "", b.Config.Resources.Jobs["my_job"].JobClusters[0].NewCluster.SparkConf["spark.random"]) + require.Equal(t, "", b.Config.Resources.Jobs["my_job"].JobClusters[0].NewCluster.PolicyId) } diff --git a/bundle/tests/variables/complex/databricks.yml b/bundle/tests/variables/complex/databricks.yml index f7535ad4..ca27f606 100644 --- a/bundle/tests/variables/complex/databricks.yml +++ b/bundle/tests/variables/complex/databricks.yml @@ -23,9 +23,11 @@ variables: spark_version: "13.2.x-scala2.11" node_type_id: ${var.node_type} num_workers: 2 + policy_id: "some-policy-id" spark_conf: spark.speculation: true spark.databricks.delta.retentionDurationCheck.enabled: false + spark.random: true libraries: type: complex description: "A libraries definition"