mirror of https://github.com/databricks/cli.git
Avoid infinite recursion when normalizing a recursive type (#1213)
## Changes This is a follow-up to #1211 prompted by the addition of a recursive type in the Go SDK v0.31.0 (`jobs.ForEachTask`). When populating missing fields with their zero values we must not inadvertently recurse into a recursive type. ## Tests New unit test fails with a stack overflow if the fix if the check is disabled.
This commit is contained in:
parent
788ec81785
commit
ea8daf1f97
|
@ -3,6 +3,7 @@ package convert
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/databricks/cli/libs/diag"
|
"github.com/databricks/cli/libs/diag"
|
||||||
|
@ -31,21 +32,21 @@ func Normalize(dst any, src dyn.Value, opts ...NormalizeOption) (dyn.Value, diag
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return n.normalizeType(reflect.TypeOf(dst), src)
|
return n.normalizeType(reflect.TypeOf(dst), src, []reflect.Type{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n normalizeOptions) normalizeType(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
|
func (n normalizeOptions) normalizeType(typ reflect.Type, src dyn.Value, seen []reflect.Type) (dyn.Value, diag.Diagnostics) {
|
||||||
for typ.Kind() == reflect.Pointer {
|
for typ.Kind() == reflect.Pointer {
|
||||||
typ = typ.Elem()
|
typ = typ.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
switch typ.Kind() {
|
switch typ.Kind() {
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
return n.normalizeStruct(typ, src)
|
return n.normalizeStruct(typ, src, append(seen, typ))
|
||||||
case reflect.Map:
|
case reflect.Map:
|
||||||
return n.normalizeMap(typ, src)
|
return n.normalizeMap(typ, src, append(seen, typ))
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
return n.normalizeSlice(typ, src)
|
return n.normalizeSlice(typ, src, append(seen, typ))
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
return n.normalizeString(typ, src)
|
return n.normalizeString(typ, src)
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
|
@ -67,7 +68,7 @@ func typeMismatch(expected dyn.Kind, src dyn.Value) diag.Diagnostic {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
|
func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value, seen []reflect.Type) (dyn.Value, diag.Diagnostics) {
|
||||||
var diags diag.Diagnostics
|
var diags diag.Diagnostics
|
||||||
|
|
||||||
switch src.Kind() {
|
switch src.Kind() {
|
||||||
|
@ -86,7 +87,7 @@ func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value) (dyn.
|
||||||
}
|
}
|
||||||
|
|
||||||
// Normalize the value according to the field type.
|
// Normalize the value according to the field type.
|
||||||
v, err := n.normalizeType(typ.FieldByIndex(index).Type, v)
|
v, err := n.normalizeType(typ.FieldByIndex(index).Type, v, seen)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
diags = diags.Extend(err)
|
diags = diags.Extend(err)
|
||||||
// Skip the element if it cannot be normalized.
|
// Skip the element if it cannot be normalized.
|
||||||
|
@ -115,20 +116,26 @@ func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value) (dyn.
|
||||||
ftyp = ftyp.Elem()
|
ftyp = ftyp.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Skip field if we have already seen its type to avoid infinite recursion
|
||||||
|
// when filling in the zero value of a recursive type.
|
||||||
|
if slices.Contains(seen, ftyp) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
var v dyn.Value
|
var v dyn.Value
|
||||||
switch ftyp.Kind() {
|
switch ftyp.Kind() {
|
||||||
case reflect.Struct, reflect.Map:
|
case reflect.Struct, reflect.Map:
|
||||||
v, _ = n.normalizeType(ftyp, dyn.V(map[string]dyn.Value{}))
|
v, _ = n.normalizeType(ftyp, dyn.V(map[string]dyn.Value{}), seen)
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
v, _ = n.normalizeType(ftyp, dyn.V([]dyn.Value{}))
|
v, _ = n.normalizeType(ftyp, dyn.V([]dyn.Value{}), seen)
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
v, _ = n.normalizeType(ftyp, dyn.V(""))
|
v, _ = n.normalizeType(ftyp, dyn.V(""), seen)
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
v, _ = n.normalizeType(ftyp, dyn.V(false))
|
v, _ = n.normalizeType(ftyp, dyn.V(false), seen)
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
v, _ = n.normalizeType(ftyp, dyn.V(int64(0)))
|
v, _ = n.normalizeType(ftyp, dyn.V(int64(0)), seen)
|
||||||
case reflect.Float32, reflect.Float64:
|
case reflect.Float32, reflect.Float64:
|
||||||
v, _ = n.normalizeType(ftyp, dyn.V(float64(0)))
|
v, _ = n.normalizeType(ftyp, dyn.V(float64(0)), seen)
|
||||||
default:
|
default:
|
||||||
// Skip fields for which we do not have a natural [dyn.Value] equivalent.
|
// Skip fields for which we do not have a natural [dyn.Value] equivalent.
|
||||||
// For example, we don't handle reflect.Complex* and reflect.Uint* types.
|
// For example, we don't handle reflect.Complex* and reflect.Uint* types.
|
||||||
|
@ -147,7 +154,7 @@ func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value) (dyn.
|
||||||
return dyn.InvalidValue, diags.Append(typeMismatch(dyn.KindMap, src))
|
return dyn.InvalidValue, diags.Append(typeMismatch(dyn.KindMap, src))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n normalizeOptions) normalizeMap(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
|
func (n normalizeOptions) normalizeMap(typ reflect.Type, src dyn.Value, seen []reflect.Type) (dyn.Value, diag.Diagnostics) {
|
||||||
var diags diag.Diagnostics
|
var diags diag.Diagnostics
|
||||||
|
|
||||||
switch src.Kind() {
|
switch src.Kind() {
|
||||||
|
@ -155,7 +162,7 @@ func (n normalizeOptions) normalizeMap(typ reflect.Type, src dyn.Value) (dyn.Val
|
||||||
out := make(map[string]dyn.Value)
|
out := make(map[string]dyn.Value)
|
||||||
for k, v := range src.MustMap() {
|
for k, v := range src.MustMap() {
|
||||||
// Normalize the value according to the map element type.
|
// Normalize the value according to the map element type.
|
||||||
v, err := n.normalizeType(typ.Elem(), v)
|
v, err := n.normalizeType(typ.Elem(), v, seen)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
diags = diags.Extend(err)
|
diags = diags.Extend(err)
|
||||||
// Skip the element if it cannot be normalized.
|
// Skip the element if it cannot be normalized.
|
||||||
|
@ -175,7 +182,7 @@ func (n normalizeOptions) normalizeMap(typ reflect.Type, src dyn.Value) (dyn.Val
|
||||||
return dyn.InvalidValue, diags.Append(typeMismatch(dyn.KindMap, src))
|
return dyn.InvalidValue, diags.Append(typeMismatch(dyn.KindMap, src))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n normalizeOptions) normalizeSlice(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
|
func (n normalizeOptions) normalizeSlice(typ reflect.Type, src dyn.Value, seen []reflect.Type) (dyn.Value, diag.Diagnostics) {
|
||||||
var diags diag.Diagnostics
|
var diags diag.Diagnostics
|
||||||
|
|
||||||
switch src.Kind() {
|
switch src.Kind() {
|
||||||
|
@ -183,7 +190,7 @@ func (n normalizeOptions) normalizeSlice(typ reflect.Type, src dyn.Value) (dyn.V
|
||||||
out := make([]dyn.Value, 0, len(src.MustSequence()))
|
out := make([]dyn.Value, 0, len(src.MustSequence()))
|
||||||
for _, v := range src.MustSequence() {
|
for _, v := range src.MustSequence() {
|
||||||
// Normalize the value according to the slice element type.
|
// Normalize the value according to the slice element type.
|
||||||
v, err := n.normalizeType(typ.Elem(), v)
|
v, err := n.normalizeType(typ.Elem(), v, seen)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
diags = diags.Extend(err)
|
diags = diags.Extend(err)
|
||||||
// Skip the element if it cannot be normalized.
|
// Skip the element if it cannot be normalized.
|
||||||
|
|
|
@ -189,6 +189,37 @@ func TestNormalizeStructIncludeMissingFields(t *testing.T) {
|
||||||
}), vout)
|
}), vout)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNormalizeStructIncludeMissingFieldsOnRecursiveType(t *testing.T) {
|
||||||
|
type Tmp struct {
|
||||||
|
// Verify that structs are recursively normalized if not set.
|
||||||
|
Ptr *Tmp `json:"ptr"`
|
||||||
|
|
||||||
|
// Verify that primitive types are zero-initialized if not set.
|
||||||
|
String string `json:"string"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var typ Tmp
|
||||||
|
vin := dyn.V(map[string]dyn.Value{
|
||||||
|
"ptr": dyn.V(map[string]dyn.Value{
|
||||||
|
"ptr": dyn.V(map[string]dyn.Value{
|
||||||
|
"string": dyn.V("already set"),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
vout, err := Normalize(typ, vin, IncludeMissingFields)
|
||||||
|
assert.Empty(t, err)
|
||||||
|
assert.Equal(t, dyn.V(map[string]dyn.Value{
|
||||||
|
"ptr": dyn.V(map[string]dyn.Value{
|
||||||
|
"ptr": dyn.V(map[string]dyn.Value{
|
||||||
|
// Note: the ptr field is not zero-initialized because that would recurse.
|
||||||
|
"string": dyn.V("already set"),
|
||||||
|
}),
|
||||||
|
"string": dyn.V(""),
|
||||||
|
}),
|
||||||
|
"string": dyn.V(""),
|
||||||
|
}), vout)
|
||||||
|
}
|
||||||
|
|
||||||
func TestNormalizeMap(t *testing.T) {
|
func TestNormalizeMap(t *testing.T) {
|
||||||
var typ map[string]string
|
var typ map[string]string
|
||||||
vin := dyn.V(map[string]dyn.Value{
|
vin := dyn.V(map[string]dyn.Value{
|
||||||
|
|
Loading…
Reference in New Issue