diff --git a/bundle/config/interpolation/interpolation.go b/bundle/config/interpolation/interpolation.go index 09dfc79ae..8eab03b4c 100644 --- a/bundle/config/interpolation/interpolation.go +++ b/bundle/config/interpolation/interpolation.go @@ -6,9 +6,12 @@ import ( "fmt" "reflect" "regexp" + "sort" "strings" "github.com/databricks/bricks/bundle" + "golang.org/x/exp/maps" + "golang.org/x/exp/slices" ) const Delimiter = "." @@ -63,7 +66,14 @@ func (s *stringField) interpolate(fns []LookupFunction, lookup map[string]string } type accumulator struct { + // all string fields in the bundle config strings map[string]*stringField + + // contains path -> resolved_string mapping for string fields in the config + // The resolved strings will NOT contain any variable references that could + // have been resolved, however there might still be references that cannot + // be resolved + memo map[string]string } // jsonFieldName returns the name in a field's `json` tag. @@ -138,25 +148,7 @@ func (a *accumulator) walk(scope []string, rv reflect.Value, s setter) { } } -// Gathers the strings for a list of paths. -// The fields in these paths may not depend on other fields, -// as we don't support full DAG lookup yet (only single level). -func (a *accumulator) gather(paths []string) (map[string]string, error) { - var out = make(map[string]string) - for _, path := range paths { - f, ok := a.strings[path] - if !ok { - return nil, fmt.Errorf("%s is not defined", path) - } - deps := f.dependsOn() - if len(deps) > 0 { - return nil, fmt.Errorf("%s depends on %s", path, strings.Join(deps, ", ")) - } - out[path] = f.Get() - } - return out, nil -} - +// walk and gather all string fields in the config func (a *accumulator) start(v any) { rv := reflect.ValueOf(v) if rv.Type().Kind() != reflect.Pointer { @@ -168,25 +160,64 @@ func (a *accumulator) start(v any) { } a.strings = make(map[string]*stringField) + a.memo = make(map[string]string) a.walk([]string{}, rv, nilSetter{}) } -func (a *accumulator) expand(fns ...LookupFunction) error { - for path, v := range a.strings { - ds := v.dependsOn() - if len(ds) == 0 { - continue - } - - // Create map to be used for interpolation - m, err := a.gather(ds) - if err != nil { - return fmt.Errorf("cannot interpolate %s: %w", path, err) - } - - v.interpolate(fns, m) +// recursively interpolate variables in a depth first manner +func (a *accumulator) Resolve(path string, seenPaths []string, fns ...LookupFunction) error { + // return early if the path is already resolved + if _, ok := a.memo[path]; ok { + return nil } + // fetch the string node to resolve + field, ok := a.strings[path] + if !ok { + return fmt.Errorf("could not find string field with path %s", path) + } + + // return early if the string field has no variables to interpolate + if len(field.dependsOn()) == 0 { + a.memo[path] = field.Get() + return nil + } + + // resolve all variables refered in the root string field + for _, childFieldPath := range field.dependsOn() { + // error if there is a loop in variable interpolation + if slices.Contains(seenPaths, childFieldPath) { + return fmt.Errorf("cycle detected in field resolution: %s", strings.Join(append(seenPaths, childFieldPath), " -> ")) + } + + // recursive resolve variables in the child fields + err := a.Resolve(childFieldPath, append(seenPaths, childFieldPath), fns...) + if err != nil { + return err + } + } + + // interpolate root string once all variable references in it have been resolved + field.interpolate(fns, a.memo) + + // record interpolated string in memo + a.memo[path] = field.Get() + return nil +} + +// Interpolate all string fields in the config +func (a *accumulator) expand(fns ...LookupFunction) error { + // sorting paths for stable order of iteration + paths := maps.Keys(a.strings) + sort.Strings(paths) + + // iterate over paths for all strings fields in the config + for _, path := range paths { + err := a.Resolve(path, []string{path}, fns...) + if err != nil { + return err + } + } return nil } diff --git a/bundle/config/interpolation/interpolation_test.go b/bundle/config/interpolation/interpolation_test.go index bce51225e..eb5848fd6 100644 --- a/bundle/config/interpolation/interpolation_test.go +++ b/bundle/config/interpolation/interpolation_test.go @@ -97,3 +97,31 @@ func TestInterpolationWithMap(t *testing.T) { assert.Equal(t, "a", f.F["a"]) assert.Equal(t, "a", f.F["b"]) } + +func TestInterpolationWithResursiveVariableReferences(t *testing.T) { + f := foo{ + A: "a", + B: "(${a})", + C: "${a} ${b}", + } + + err := expand(&f) + require.NoError(t, err) + + assert.Equal(t, "a", f.A) + assert.Equal(t, "(a)", f.B) + assert.Equal(t, "a (a)", f.C) +} + +func TestInterpolationVariableLoopError(t *testing.T) { + d := "${b}" + f := foo{ + A: "a", + B: "${c}", + C: "${d}", + D: &d, + } + + err := expand(&f) + assert.ErrorContains(t, err, "cycle detected in field resolution: b -> c -> d -> b") +} diff --git a/bundle/tests/interpolation/bundle.yml b/bundle/tests/interpolation/bundle.yml new file mode 100644 index 000000000..a31af0278 --- /dev/null +++ b/bundle/tests/interpolation/bundle.yml @@ -0,0 +1,10 @@ +bundle: + name: foo ${workspace.profile} + +workspace: + profile: bar + +resources: + jobs: + my_job: + name: "${bundle.name} | ${workspace.profile}" diff --git a/bundle/tests/interpolation_test.go b/bundle/tests/interpolation_test.go new file mode 100644 index 000000000..cd1f506a7 --- /dev/null +++ b/bundle/tests/interpolation_test.go @@ -0,0 +1,23 @@ +package config_tests + +import ( + "context" + "testing" + + "github.com/databricks/bricks/bundle" + "github.com/databricks/bricks/bundle/config/interpolation" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInterpolation(t *testing.T) { + b := load(t, "./interpolation") + err := bundle.Apply(context.Background(), b, []bundle.Mutator{ + interpolation.Interpolate( + interpolation.IncludeLookupsInPath("bundle"), + interpolation.IncludeLookupsInPath("workspace"), + )}) + require.NoError(t, err) + assert.Equal(t, "foo bar", b.Config.Bundle.Name) + assert.Equal(t, "foo bar | bar", b.Config.Resources.Jobs["my_job"].Name) +}