From a2a4948047e7e119ead809f59c80299900aeda32 Mon Sep 17 00:00:00 2001 From: Pieter Noordhuis Date: Mon, 19 Feb 2024 11:44:51 +0100 Subject: [PATCH] Allow use of variables references in primitive non-string fields (#1219) ## Changes This change enables the use of bundle variables for boolean, integer, and floating point fields. ## Tests * Unit tests. * I ran a manual test to confirm parameterizing the number of workers in a cluster definition works. --- .../mutator/resolve_variable_references.go | 16 ++- .../resolve_variable_references_test.go | 97 +++++++++++++++++++ libs/dyn/convert/from_typed.go | 16 +++ libs/dyn/convert/from_typed_test.go | 24 +++++ libs/dyn/convert/normalize.go | 16 +++ libs/dyn/convert/normalize_test.go | 24 +++++ libs/dyn/convert/to_typed.go | 16 +++ libs/dyn/convert/to_typed_test.go | 30 +++++- libs/dyn/dynvar/ref.go | 4 + libs/dyn/dynvar/ref_test.go | 7 ++ 10 files changed, 248 insertions(+), 2 deletions(-) diff --git a/bundle/config/mutator/resolve_variable_references.go b/bundle/config/mutator/resolve_variable_references.go index a9ff70f6..1075e83e 100644 --- a/bundle/config/mutator/resolve_variable_references.go +++ b/bundle/config/mutator/resolve_variable_references.go @@ -7,6 +7,7 @@ import ( "github.com/databricks/cli/libs/dyn" "github.com/databricks/cli/libs/dyn/convert" "github.com/databricks/cli/libs/dyn/dynvar" + "github.com/databricks/cli/libs/log" ) type resolveVariableReferences struct { @@ -58,7 +59,7 @@ func (m *resolveVariableReferences) Apply(ctx context.Context, b *bundle.Bundle) } // Resolve variable references in all values. - return dynvar.Resolve(root, func(path dyn.Path) (dyn.Value, error) { + root, err := dynvar.Resolve(root, func(path dyn.Path) (dyn.Value, error) { // Rewrite the shorthand path ${var.foo} into ${variables.foo.value}. if path.HasPrefix(varPath) && len(path) == 2 { path = dyn.NewPath( @@ -77,5 +78,18 @@ func (m *resolveVariableReferences) Apply(ctx context.Context, b *bundle.Bundle) return dyn.InvalidValue, dynvar.ErrSkipResolution }) + if err != nil { + return dyn.InvalidValue, err + } + + // Normalize the result because variable resolution may have been applied to non-string fields. + // For example, a variable reference may have been resolved to a integer. + root, diags := convert.Normalize(b.Config, root) + for _, diag := range diags { + // This occurs when a variable's resolved value is incompatible with the field's type. + // Log a warning until we have a better way to surface these diagnostics to the user. + log.Warnf(ctx, "normalization diagnostic: %s", diag.Summary) + } + return root, nil }) } diff --git a/bundle/config/mutator/resolve_variable_references_test.go b/bundle/config/mutator/resolve_variable_references_test.go index 1f253d41..8190c360 100644 --- a/bundle/config/mutator/resolve_variable_references_test.go +++ b/bundle/config/mutator/resolve_variable_references_test.go @@ -8,7 +8,10 @@ import ( "github.com/databricks/cli/bundle/config" "github.com/databricks/cli/bundle/config/resources" "github.com/databricks/cli/bundle/config/variable" + "github.com/databricks/cli/libs/dyn" + "github.com/databricks/databricks-sdk-go/service/compute" "github.com/databricks/databricks-sdk-go/service/jobs" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -95,3 +98,97 @@ func TestResolveVariableReferencesToEmptyFields(t *testing.T) { // The job settings should have been interpolated to an empty string. require.Equal(t, "", b.Config.Resources.Jobs["job1"].JobSettings.Tags["git_branch"]) } + +func TestResolveVariableReferencesForPrimitiveNonStringFields(t *testing.T) { + var err error + + b := &bundle.Bundle{ + Config: config.Root{ + Variables: map[string]*variable.Variable{ + "no_alert_for_canceled_runs": {}, + "no_alert_for_skipped_runs": {}, + "min_workers": {}, + "max_workers": {}, + "spot_bid_max_price": {}, + }, + Resources: config.Resources{ + Jobs: map[string]*resources.Job{ + "job1": { + JobSettings: &jobs.JobSettings{ + NotificationSettings: &jobs.JobNotificationSettings{ + NoAlertForCanceledRuns: false, + NoAlertForSkippedRuns: false, + }, + Tasks: []jobs.Task{ + { + NewCluster: &compute.ClusterSpec{ + Autoscale: &compute.AutoScale{ + MinWorkers: 0, + MaxWorkers: 0, + }, + AzureAttributes: &compute.AzureAttributes{ + SpotBidMaxPrice: 0.0, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + ctx := context.Background() + + // Initialize the variables. + err = bundle.ApplyFunc(ctx, b, func(ctx context.Context, b *bundle.Bundle) error { + return b.Config.InitializeVariables([]string{ + "no_alert_for_canceled_runs=true", + "no_alert_for_skipped_runs=true", + "min_workers=1", + "max_workers=2", + "spot_bid_max_price=0.5", + }) + }) + require.NoError(t, err) + + // Assign the variables to the dynamic configuration. + err = bundle.ApplyFunc(ctx, b, func(ctx context.Context, b *bundle.Bundle) error { + return b.Config.Mutate(func(v dyn.Value) (dyn.Value, error) { + var p dyn.Path + var err error + + // Set the notification settings. + p = dyn.MustPathFromString("resources.jobs.job1.notification_settings") + v, err = dyn.SetByPath(v, p.Append(dyn.Key("no_alert_for_canceled_runs")), dyn.V("${var.no_alert_for_canceled_runs}")) + require.NoError(t, err) + v, err = dyn.SetByPath(v, p.Append(dyn.Key("no_alert_for_skipped_runs")), dyn.V("${var.no_alert_for_skipped_runs}")) + require.NoError(t, err) + + // Set the min and max workers. + p = dyn.MustPathFromString("resources.jobs.job1.tasks[0].new_cluster.autoscale") + v, err = dyn.SetByPath(v, p.Append(dyn.Key("min_workers")), dyn.V("${var.min_workers}")) + require.NoError(t, err) + v, err = dyn.SetByPath(v, p.Append(dyn.Key("max_workers")), dyn.V("${var.max_workers}")) + require.NoError(t, err) + + // Set the spot bid max price. + p = dyn.MustPathFromString("resources.jobs.job1.tasks[0].new_cluster.azure_attributes") + v, err = dyn.SetByPath(v, p.Append(dyn.Key("spot_bid_max_price")), dyn.V("${var.spot_bid_max_price}")) + require.NoError(t, err) + + return v, nil + }) + }) + require.NoError(t, err) + + // Apply for the variable prefix. This should resolve the variables to their values. + err = bundle.Apply(context.Background(), b, ResolveVariableReferences("variables")) + require.NoError(t, err) + assert.Equal(t, true, b.Config.Resources.Jobs["job1"].JobSettings.NotificationSettings.NoAlertForCanceledRuns) + assert.Equal(t, true, b.Config.Resources.Jobs["job1"].JobSettings.NotificationSettings.NoAlertForSkippedRuns) + assert.Equal(t, 1, b.Config.Resources.Jobs["job1"].JobSettings.Tasks[0].NewCluster.Autoscale.MinWorkers) + assert.Equal(t, 2, b.Config.Resources.Jobs["job1"].JobSettings.Tasks[0].NewCluster.Autoscale.MaxWorkers) + assert.Equal(t, 0.5, b.Config.Resources.Jobs["job1"].JobSettings.Tasks[0].NewCluster.AzureAttributes.SpotBidMaxPrice) +} diff --git a/libs/dyn/convert/from_typed.go b/libs/dyn/convert/from_typed.go index 6dcca2b8..4778edb9 100644 --- a/libs/dyn/convert/from_typed.go +++ b/libs/dyn/convert/from_typed.go @@ -6,6 +6,7 @@ import ( "slices" "github.com/databricks/cli/libs/dyn" + "github.com/databricks/cli/libs/dyn/dynvar" ) type fromTypedOptions int @@ -185,6 +186,11 @@ func fromTypedBool(src reflect.Value, ref dyn.Value, options ...fromTypedOptions return dyn.NilValue, nil } return dyn.V(src.Bool()), nil + case dyn.KindString: + // Ignore pure variable references (e.g. ${var.foo}). + if dynvar.IsPureVariableReference(ref.MustString()) { + return ref, nil + } } return dyn.InvalidValue, fmt.Errorf("unhandled type: %s", ref.Kind()) @@ -205,6 +211,11 @@ func fromTypedInt(src reflect.Value, ref dyn.Value, options ...fromTypedOptions) return dyn.NilValue, nil } return dyn.V(src.Int()), nil + case dyn.KindString: + // Ignore pure variable references (e.g. ${var.foo}). + if dynvar.IsPureVariableReference(ref.MustString()) { + return ref, nil + } } return dyn.InvalidValue, fmt.Errorf("unhandled type: %s", ref.Kind()) @@ -225,6 +236,11 @@ func fromTypedFloat(src reflect.Value, ref dyn.Value, options ...fromTypedOption return dyn.NilValue, nil } return dyn.V(src.Float()), nil + case dyn.KindString: + // Ignore pure variable references (e.g. ${var.foo}). + if dynvar.IsPureVariableReference(ref.MustString()) { + return ref, nil + } } return dyn.InvalidValue, fmt.Errorf("unhandled type: %s", ref.Kind()) diff --git a/libs/dyn/convert/from_typed_test.go b/libs/dyn/convert/from_typed_test.go index 5fc2b90f..f7e97fc7 100644 --- a/libs/dyn/convert/from_typed_test.go +++ b/libs/dyn/convert/from_typed_test.go @@ -495,6 +495,14 @@ func TestFromTypedBoolRetainsLocationsIfUnchanged(t *testing.T) { assert.Equal(t, dyn.NewValue(true, dyn.Location{File: "foo"}), nv) } +func TestFromTypedBoolVariableReference(t *testing.T) { + var src bool = true + var ref = dyn.V("${var.foo}") + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, dyn.V("${var.foo}"), nv) +} + func TestFromTypedBoolTypeError(t *testing.T) { var src bool = true var ref = dyn.V("string") @@ -542,6 +550,14 @@ func TestFromTypedIntRetainsLocationsIfUnchanged(t *testing.T) { assert.Equal(t, dyn.NewValue(1234, dyn.Location{File: "foo"}), nv) } +func TestFromTypedIntVariableReference(t *testing.T) { + var src int = 1234 + var ref = dyn.V("${var.foo}") + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, dyn.V("${var.foo}"), nv) +} + func TestFromTypedIntTypeError(t *testing.T) { var src int = 1234 var ref = dyn.V("string") @@ -589,6 +605,14 @@ func TestFromTypedFloatRetainsLocationsIfUnchanged(t *testing.T) { assert.Equal(t, dyn.NewValue(1.23, dyn.Location{File: "foo"}), nv) } +func TestFromTypedFloatVariableReference(t *testing.T) { + var src float64 = 1.23 + var ref = dyn.V("${var.foo}") + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, dyn.V("${var.foo}"), nv) +} + func TestFromTypedFloatTypeError(t *testing.T) { var src float64 = 1.23 var ref = dyn.V("string") diff --git a/libs/dyn/convert/normalize.go b/libs/dyn/convert/normalize.go index e0dfbda2..d6539be9 100644 --- a/libs/dyn/convert/normalize.go +++ b/libs/dyn/convert/normalize.go @@ -8,6 +8,7 @@ import ( "github.com/databricks/cli/libs/diag" "github.com/databricks/cli/libs/dyn" + "github.com/databricks/cli/libs/dyn/dynvar" ) // NormalizeOption is the type for options that can be passed to Normalize. @@ -245,6 +246,11 @@ func (n normalizeOptions) normalizeBool(typ reflect.Type, src dyn.Value) (dyn.Va case "false", "n", "N", "no", "No", "NO", "off", "Off", "OFF": out = false default: + // Return verbatim if it's a pure variable reference. + if dynvar.IsPureVariableReference(src.MustString()) { + return src, nil + } + // Cannot interpret as a boolean. return dyn.InvalidValue, diags.Append(typeMismatch(dyn.KindBool, src)) } @@ -266,6 +272,11 @@ func (n normalizeOptions) normalizeInt(typ reflect.Type, src dyn.Value) (dyn.Val var err error out, err = strconv.ParseInt(src.MustString(), 10, 64) if err != nil { + // Return verbatim if it's a pure variable reference. + if dynvar.IsPureVariableReference(src.MustString()) { + return src, nil + } + return dyn.InvalidValue, diags.Append(diag.Diagnostic{ Severity: diag.Error, Summary: fmt.Sprintf("cannot parse %q as an integer", src.MustString()), @@ -290,6 +301,11 @@ func (n normalizeOptions) normalizeFloat(typ reflect.Type, src dyn.Value) (dyn.V var err error out, err = strconv.ParseFloat(src.MustString(), 64) if err != nil { + // Return verbatim if it's a pure variable reference. + if dynvar.IsPureVariableReference(src.MustString()) { + return src, nil + } + return dyn.InvalidValue, diags.Append(diag.Diagnostic{ Severity: diag.Error, Summary: fmt.Sprintf("cannot parse %q as a floating point number", src.MustString()), diff --git a/libs/dyn/convert/normalize_test.go b/libs/dyn/convert/normalize_test.go index 82abc826..a2a6038e 100644 --- a/libs/dyn/convert/normalize_test.go +++ b/libs/dyn/convert/normalize_test.go @@ -490,6 +490,14 @@ func TestNormalizeBoolFromString(t *testing.T) { } } +func TestNormalizeBoolFromStringVariableReference(t *testing.T) { + var typ bool + vin := dyn.V("${var.foo}") + vout, err := Normalize(&typ, vin) + assert.Empty(t, err) + assert.Equal(t, vin, vout) +} + func TestNormalizeBoolFromStringError(t *testing.T) { var typ bool vin := dyn.V("abc") @@ -542,6 +550,14 @@ func TestNormalizeIntFromString(t *testing.T) { assert.Equal(t, dyn.V(int64(123)), vout) } +func TestNormalizeIntFromStringVariableReference(t *testing.T) { + var typ int + vin := dyn.V("${var.foo}") + vout, err := Normalize(&typ, vin) + assert.Empty(t, err) + assert.Equal(t, vin, vout) +} + func TestNormalizeIntFromStringError(t *testing.T) { var typ int vin := dyn.V("abc") @@ -594,6 +610,14 @@ func TestNormalizeFloatFromString(t *testing.T) { assert.Equal(t, dyn.V(1.2), vout) } +func TestNormalizeFloatFromStringVariableReference(t *testing.T) { + var typ float64 + vin := dyn.V("${var.foo}") + vout, err := Normalize(&typ, vin) + assert.Empty(t, err) + assert.Equal(t, vin, vout) +} + func TestNormalizeFloatFromStringError(t *testing.T) { var typ float64 vin := dyn.V("abc") diff --git a/libs/dyn/convert/to_typed.go b/libs/dyn/convert/to_typed.go index 715d3f67..aeaaa9be 100644 --- a/libs/dyn/convert/to_typed.go +++ b/libs/dyn/convert/to_typed.go @@ -6,6 +6,7 @@ import ( "strconv" "github.com/databricks/cli/libs/dyn" + "github.com/databricks/cli/libs/dyn/dynvar" ) func ToTyped(dst any, src dyn.Value) error { @@ -195,6 +196,11 @@ func toTypedBool(dst reflect.Value, src dyn.Value) error { dst.SetBool(false) return nil } + // Ignore pure variable references (e.g. ${var.foo}). + if dynvar.IsPureVariableReference(src.MustString()) { + dst.SetZero() + return nil + } } return TypeError{ @@ -213,6 +219,11 @@ func toTypedInt(dst reflect.Value, src dyn.Value) error { dst.SetInt(i64) return nil } + // Ignore pure variable references (e.g. ${var.foo}). + if dynvar.IsPureVariableReference(src.MustString()) { + dst.SetZero() + return nil + } } return TypeError{ @@ -231,6 +242,11 @@ func toTypedFloat(dst reflect.Value, src dyn.Value) error { dst.SetFloat(f64) return nil } + // Ignore pure variable references (e.g. ${var.foo}). + if dynvar.IsPureVariableReference(src.MustString()) { + dst.SetZero() + return nil + } } return TypeError{ diff --git a/libs/dyn/convert/to_typed_test.go b/libs/dyn/convert/to_typed_test.go index fd399b93..a7c4a6f0 100644 --- a/libs/dyn/convert/to_typed_test.go +++ b/libs/dyn/convert/to_typed_test.go @@ -355,10 +355,17 @@ func TestToTypedBoolFromString(t *testing.T) { } // Other - err := ToTyped(&out, dyn.V("${var.foo}")) + err := ToTyped(&out, dyn.V("some other string")) require.Error(t, err) } +func TestToTypedBoolFromStringVariableReference(t *testing.T) { + var out bool = true + err := ToTyped(&out, dyn.V("${var.foo}")) + require.NoError(t, err) + assert.Equal(t, false, out) +} + func TestToTypedInt(t *testing.T) { var out int err := ToTyped(&out, dyn.V(1234)) @@ -414,6 +421,13 @@ func TestToTypedIntFromStringInt(t *testing.T) { assert.Equal(t, int(123), out) } +func TestToTypedIntFromStringVariableReference(t *testing.T) { + var out int = 123 + err := ToTyped(&out, dyn.V("${var.foo}")) + require.NoError(t, err) + assert.Equal(t, int(0), out) +} + func TestToTypedFloat32(t *testing.T) { var out float32 err := ToTyped(&out, dyn.V(float32(1.0))) @@ -467,3 +481,17 @@ func TestToTypedFloat64FromString(t *testing.T) { require.NoError(t, err) assert.Equal(t, float64(1.2), out) } + +func TestToTypedFloat32FromStringVariableReference(t *testing.T) { + var out float32 = 1.0 + err := ToTyped(&out, dyn.V("${var.foo}")) + require.NoError(t, err) + assert.Equal(t, float32(0.0), out) +} + +func TestToTypedFloat64FromStringVariableReference(t *testing.T) { + var out float64 = 1.0 + err := ToTyped(&out, dyn.V("${var.foo}")) + require.NoError(t, err) + assert.Equal(t, float64(0.0), out) +} diff --git a/libs/dyn/dynvar/ref.go b/libs/dyn/dynvar/ref.go index e4616c52..a2047032 100644 --- a/libs/dyn/dynvar/ref.go +++ b/libs/dyn/dynvar/ref.go @@ -67,3 +67,7 @@ func (v ref) references() []string { } return out } + +func IsPureVariableReference(s string) bool { + return len(s) > 0 && re.FindString(s) == s +} diff --git a/libs/dyn/dynvar/ref_test.go b/libs/dyn/dynvar/ref_test.go index b3066276..09223736 100644 --- a/libs/dyn/dynvar/ref_test.go +++ b/libs/dyn/dynvar/ref_test.go @@ -44,3 +44,10 @@ func TestNewRefInvalidPattern(t *testing.T) { require.False(t, ok, "should not match invalid pattern: %s", v) } } + +func TestIsPureVariableReference(t *testing.T) { + assert.False(t, IsPureVariableReference("")) + assert.False(t, IsPureVariableReference("${foo.bar} suffix")) + assert.False(t, IsPureVariableReference("prefix ${foo.bar}")) + assert.True(t, IsPureVariableReference("${foo.bar}")) +}