Avoid infinite recursion when normalizing a recursive type (#1213)

## Changes

This is a follow-up to #1211 prompted by the addition of a recursive
type in the Go SDK v0.31.0 (`jobs.ForEachTask`).

When populating missing fields with their zero values we must not
inadvertently recurse into a recursive type.

## Tests

New unit test fails with a stack overflow if the fix if the check is
disabled.
This commit is contained in:
Pieter Noordhuis 2024-02-16 13:56:02 +01:00 committed by GitHub
parent 788ec81785
commit ea8daf1f97
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 55 additions and 17 deletions

View File

@ -3,6 +3,7 @@ package convert
import (
"fmt"
"reflect"
"slices"
"strconv"
"github.com/databricks/cli/libs/diag"
@ -31,21 +32,21 @@ func Normalize(dst any, src dyn.Value, opts ...NormalizeOption) (dyn.Value, diag
}
}
return n.normalizeType(reflect.TypeOf(dst), src)
return n.normalizeType(reflect.TypeOf(dst), src, []reflect.Type{})
}
func (n normalizeOptions) normalizeType(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
func (n normalizeOptions) normalizeType(typ reflect.Type, src dyn.Value, seen []reflect.Type) (dyn.Value, diag.Diagnostics) {
for typ.Kind() == reflect.Pointer {
typ = typ.Elem()
}
switch typ.Kind() {
case reflect.Struct:
return n.normalizeStruct(typ, src)
return n.normalizeStruct(typ, src, append(seen, typ))
case reflect.Map:
return n.normalizeMap(typ, src)
return n.normalizeMap(typ, src, append(seen, typ))
case reflect.Slice:
return n.normalizeSlice(typ, src)
return n.normalizeSlice(typ, src, append(seen, typ))
case reflect.String:
return n.normalizeString(typ, src)
case reflect.Bool:
@ -67,7 +68,7 @@ func typeMismatch(expected dyn.Kind, src dyn.Value) diag.Diagnostic {
}
}
func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value, seen []reflect.Type) (dyn.Value, diag.Diagnostics) {
var diags diag.Diagnostics
switch src.Kind() {
@ -86,7 +87,7 @@ func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value) (dyn.
}
// Normalize the value according to the field type.
v, err := n.normalizeType(typ.FieldByIndex(index).Type, v)
v, err := n.normalizeType(typ.FieldByIndex(index).Type, v, seen)
if err != nil {
diags = diags.Extend(err)
// Skip the element if it cannot be normalized.
@ -115,20 +116,26 @@ func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value) (dyn.
ftyp = ftyp.Elem()
}
// Skip field if we have already seen its type to avoid infinite recursion
// when filling in the zero value of a recursive type.
if slices.Contains(seen, ftyp) {
continue
}
var v dyn.Value
switch ftyp.Kind() {
case reflect.Struct, reflect.Map:
v, _ = n.normalizeType(ftyp, dyn.V(map[string]dyn.Value{}))
v, _ = n.normalizeType(ftyp, dyn.V(map[string]dyn.Value{}), seen)
case reflect.Slice:
v, _ = n.normalizeType(ftyp, dyn.V([]dyn.Value{}))
v, _ = n.normalizeType(ftyp, dyn.V([]dyn.Value{}), seen)
case reflect.String:
v, _ = n.normalizeType(ftyp, dyn.V(""))
v, _ = n.normalizeType(ftyp, dyn.V(""), seen)
case reflect.Bool:
v, _ = n.normalizeType(ftyp, dyn.V(false))
v, _ = n.normalizeType(ftyp, dyn.V(false), seen)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
v, _ = n.normalizeType(ftyp, dyn.V(int64(0)))
v, _ = n.normalizeType(ftyp, dyn.V(int64(0)), seen)
case reflect.Float32, reflect.Float64:
v, _ = n.normalizeType(ftyp, dyn.V(float64(0)))
v, _ = n.normalizeType(ftyp, dyn.V(float64(0)), seen)
default:
// Skip fields for which we do not have a natural [dyn.Value] equivalent.
// For example, we don't handle reflect.Complex* and reflect.Uint* types.
@ -147,7 +154,7 @@ func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value) (dyn.
return dyn.InvalidValue, diags.Append(typeMismatch(dyn.KindMap, src))
}
func (n normalizeOptions) normalizeMap(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
func (n normalizeOptions) normalizeMap(typ reflect.Type, src dyn.Value, seen []reflect.Type) (dyn.Value, diag.Diagnostics) {
var diags diag.Diagnostics
switch src.Kind() {
@ -155,7 +162,7 @@ func (n normalizeOptions) normalizeMap(typ reflect.Type, src dyn.Value) (dyn.Val
out := make(map[string]dyn.Value)
for k, v := range src.MustMap() {
// Normalize the value according to the map element type.
v, err := n.normalizeType(typ.Elem(), v)
v, err := n.normalizeType(typ.Elem(), v, seen)
if err != nil {
diags = diags.Extend(err)
// Skip the element if it cannot be normalized.
@ -175,7 +182,7 @@ func (n normalizeOptions) normalizeMap(typ reflect.Type, src dyn.Value) (dyn.Val
return dyn.InvalidValue, diags.Append(typeMismatch(dyn.KindMap, src))
}
func (n normalizeOptions) normalizeSlice(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
func (n normalizeOptions) normalizeSlice(typ reflect.Type, src dyn.Value, seen []reflect.Type) (dyn.Value, diag.Diagnostics) {
var diags diag.Diagnostics
switch src.Kind() {
@ -183,7 +190,7 @@ func (n normalizeOptions) normalizeSlice(typ reflect.Type, src dyn.Value) (dyn.V
out := make([]dyn.Value, 0, len(src.MustSequence()))
for _, v := range src.MustSequence() {
// Normalize the value according to the slice element type.
v, err := n.normalizeType(typ.Elem(), v)
v, err := n.normalizeType(typ.Elem(), v, seen)
if err != nil {
diags = diags.Extend(err)
// Skip the element if it cannot be normalized.

View File

@ -189,6 +189,37 @@ func TestNormalizeStructIncludeMissingFields(t *testing.T) {
}), vout)
}
func TestNormalizeStructIncludeMissingFieldsOnRecursiveType(t *testing.T) {
type Tmp struct {
// Verify that structs are recursively normalized if not set.
Ptr *Tmp `json:"ptr"`
// Verify that primitive types are zero-initialized if not set.
String string `json:"string"`
}
var typ Tmp
vin := dyn.V(map[string]dyn.Value{
"ptr": dyn.V(map[string]dyn.Value{
"ptr": dyn.V(map[string]dyn.Value{
"string": dyn.V("already set"),
}),
}),
})
vout, err := Normalize(typ, vin, IncludeMissingFields)
assert.Empty(t, err)
assert.Equal(t, dyn.V(map[string]dyn.Value{
"ptr": dyn.V(map[string]dyn.Value{
"ptr": dyn.V(map[string]dyn.Value{
// Note: the ptr field is not zero-initialized because that would recurse.
"string": dyn.V("already set"),
}),
"string": dyn.V(""),
}),
"string": dyn.V(""),
}), vout)
}
func TestNormalizeMap(t *testing.T) {
var typ map[string]string
vin := dyn.V(map[string]dyn.Value{