Avoid infinite recursion when normalizing a recursive type (#1213)

## Changes

This is a follow-up to #1211 prompted by the addition of a recursive
type in the Go SDK v0.31.0 (`jobs.ForEachTask`).

When populating missing fields with their zero values we must not
inadvertently recurse into a recursive type.

## Tests

New unit test fails with a stack overflow if the fix if the check is
disabled.
This commit is contained in:
Pieter Noordhuis 2024-02-16 13:56:02 +01:00 committed by GitHub
parent 788ec81785
commit ea8daf1f97
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 55 additions and 17 deletions

View File

@ -3,6 +3,7 @@ package convert
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"slices"
"strconv" "strconv"
"github.com/databricks/cli/libs/diag" "github.com/databricks/cli/libs/diag"
@ -31,21 +32,21 @@ func Normalize(dst any, src dyn.Value, opts ...NormalizeOption) (dyn.Value, diag
} }
} }
return n.normalizeType(reflect.TypeOf(dst), src) return n.normalizeType(reflect.TypeOf(dst), src, []reflect.Type{})
} }
func (n normalizeOptions) normalizeType(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) { func (n normalizeOptions) normalizeType(typ reflect.Type, src dyn.Value, seen []reflect.Type) (dyn.Value, diag.Diagnostics) {
for typ.Kind() == reflect.Pointer { for typ.Kind() == reflect.Pointer {
typ = typ.Elem() typ = typ.Elem()
} }
switch typ.Kind() { switch typ.Kind() {
case reflect.Struct: case reflect.Struct:
return n.normalizeStruct(typ, src) return n.normalizeStruct(typ, src, append(seen, typ))
case reflect.Map: case reflect.Map:
return n.normalizeMap(typ, src) return n.normalizeMap(typ, src, append(seen, typ))
case reflect.Slice: case reflect.Slice:
return n.normalizeSlice(typ, src) return n.normalizeSlice(typ, src, append(seen, typ))
case reflect.String: case reflect.String:
return n.normalizeString(typ, src) return n.normalizeString(typ, src)
case reflect.Bool: case reflect.Bool:
@ -67,7 +68,7 @@ func typeMismatch(expected dyn.Kind, src dyn.Value) diag.Diagnostic {
} }
} }
func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) { func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value, seen []reflect.Type) (dyn.Value, diag.Diagnostics) {
var diags diag.Diagnostics var diags diag.Diagnostics
switch src.Kind() { switch src.Kind() {
@ -86,7 +87,7 @@ func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value) (dyn.
} }
// Normalize the value according to the field type. // Normalize the value according to the field type.
v, err := n.normalizeType(typ.FieldByIndex(index).Type, v) v, err := n.normalizeType(typ.FieldByIndex(index).Type, v, seen)
if err != nil { if err != nil {
diags = diags.Extend(err) diags = diags.Extend(err)
// Skip the element if it cannot be normalized. // Skip the element if it cannot be normalized.
@ -115,20 +116,26 @@ func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value) (dyn.
ftyp = ftyp.Elem() ftyp = ftyp.Elem()
} }
// Skip field if we have already seen its type to avoid infinite recursion
// when filling in the zero value of a recursive type.
if slices.Contains(seen, ftyp) {
continue
}
var v dyn.Value var v dyn.Value
switch ftyp.Kind() { switch ftyp.Kind() {
case reflect.Struct, reflect.Map: case reflect.Struct, reflect.Map:
v, _ = n.normalizeType(ftyp, dyn.V(map[string]dyn.Value{})) v, _ = n.normalizeType(ftyp, dyn.V(map[string]dyn.Value{}), seen)
case reflect.Slice: case reflect.Slice:
v, _ = n.normalizeType(ftyp, dyn.V([]dyn.Value{})) v, _ = n.normalizeType(ftyp, dyn.V([]dyn.Value{}), seen)
case reflect.String: case reflect.String:
v, _ = n.normalizeType(ftyp, dyn.V("")) v, _ = n.normalizeType(ftyp, dyn.V(""), seen)
case reflect.Bool: case reflect.Bool:
v, _ = n.normalizeType(ftyp, dyn.V(false)) v, _ = n.normalizeType(ftyp, dyn.V(false), seen)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
v, _ = n.normalizeType(ftyp, dyn.V(int64(0))) v, _ = n.normalizeType(ftyp, dyn.V(int64(0)), seen)
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
v, _ = n.normalizeType(ftyp, dyn.V(float64(0))) v, _ = n.normalizeType(ftyp, dyn.V(float64(0)), seen)
default: default:
// Skip fields for which we do not have a natural [dyn.Value] equivalent. // Skip fields for which we do not have a natural [dyn.Value] equivalent.
// For example, we don't handle reflect.Complex* and reflect.Uint* types. // For example, we don't handle reflect.Complex* and reflect.Uint* types.
@ -147,7 +154,7 @@ func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value) (dyn.
return dyn.InvalidValue, diags.Append(typeMismatch(dyn.KindMap, src)) return dyn.InvalidValue, diags.Append(typeMismatch(dyn.KindMap, src))
} }
func (n normalizeOptions) normalizeMap(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) { func (n normalizeOptions) normalizeMap(typ reflect.Type, src dyn.Value, seen []reflect.Type) (dyn.Value, diag.Diagnostics) {
var diags diag.Diagnostics var diags diag.Diagnostics
switch src.Kind() { switch src.Kind() {
@ -155,7 +162,7 @@ func (n normalizeOptions) normalizeMap(typ reflect.Type, src dyn.Value) (dyn.Val
out := make(map[string]dyn.Value) out := make(map[string]dyn.Value)
for k, v := range src.MustMap() { for k, v := range src.MustMap() {
// Normalize the value according to the map element type. // Normalize the value according to the map element type.
v, err := n.normalizeType(typ.Elem(), v) v, err := n.normalizeType(typ.Elem(), v, seen)
if err != nil { if err != nil {
diags = diags.Extend(err) diags = diags.Extend(err)
// Skip the element if it cannot be normalized. // Skip the element if it cannot be normalized.
@ -175,7 +182,7 @@ func (n normalizeOptions) normalizeMap(typ reflect.Type, src dyn.Value) (dyn.Val
return dyn.InvalidValue, diags.Append(typeMismatch(dyn.KindMap, src)) return dyn.InvalidValue, diags.Append(typeMismatch(dyn.KindMap, src))
} }
func (n normalizeOptions) normalizeSlice(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) { func (n normalizeOptions) normalizeSlice(typ reflect.Type, src dyn.Value, seen []reflect.Type) (dyn.Value, diag.Diagnostics) {
var diags diag.Diagnostics var diags diag.Diagnostics
switch src.Kind() { switch src.Kind() {
@ -183,7 +190,7 @@ func (n normalizeOptions) normalizeSlice(typ reflect.Type, src dyn.Value) (dyn.V
out := make([]dyn.Value, 0, len(src.MustSequence())) out := make([]dyn.Value, 0, len(src.MustSequence()))
for _, v := range src.MustSequence() { for _, v := range src.MustSequence() {
// Normalize the value according to the slice element type. // Normalize the value according to the slice element type.
v, err := n.normalizeType(typ.Elem(), v) v, err := n.normalizeType(typ.Elem(), v, seen)
if err != nil { if err != nil {
diags = diags.Extend(err) diags = diags.Extend(err)
// Skip the element if it cannot be normalized. // Skip the element if it cannot be normalized.

View File

@ -189,6 +189,37 @@ func TestNormalizeStructIncludeMissingFields(t *testing.T) {
}), vout) }), vout)
} }
func TestNormalizeStructIncludeMissingFieldsOnRecursiveType(t *testing.T) {
type Tmp struct {
// Verify that structs are recursively normalized if not set.
Ptr *Tmp `json:"ptr"`
// Verify that primitive types are zero-initialized if not set.
String string `json:"string"`
}
var typ Tmp
vin := dyn.V(map[string]dyn.Value{
"ptr": dyn.V(map[string]dyn.Value{
"ptr": dyn.V(map[string]dyn.Value{
"string": dyn.V("already set"),
}),
}),
})
vout, err := Normalize(typ, vin, IncludeMissingFields)
assert.Empty(t, err)
assert.Equal(t, dyn.V(map[string]dyn.Value{
"ptr": dyn.V(map[string]dyn.Value{
"ptr": dyn.V(map[string]dyn.Value{
// Note: the ptr field is not zero-initialized because that would recurse.
"string": dyn.V("already set"),
}),
"string": dyn.V(""),
}),
"string": dyn.V(""),
}), vout)
}
func TestNormalizeMap(t *testing.T) { func TestNormalizeMap(t *testing.T) {
var typ map[string]string var typ map[string]string
vin := dyn.V(map[string]dyn.Value{ vin := dyn.V(map[string]dyn.Value{