diff --git a/bundle/schema/schema.go b/bundle/schema/schema.go index aaf0868f3..f70e47663 100644 --- a/bundle/schema/schema.go +++ b/bundle/schema/schema.go @@ -61,11 +61,10 @@ between json schema types and golang types for details visit: https://json-schema.org/understanding-json-schema/reference/object.html#properties */ func NewSchema(golangType reflect.Type, docs *Docs) (*Schema, error) { - seenTypes := map[reflect.Type]struct{}{} - debugTrace := list.New() - schema, err := safeToSchema(golangType, docs, "", seenTypes, debugTrace) + tracker := newTracker() + schema, err := safeToSchema(golangType, docs, "", tracker) if err != nil { - return nil, errWithTrace(err.Error(), debugTrace) + return nil, tracker.errWithTrace(err.Error()) } return schema, nil } @@ -106,20 +105,8 @@ func javascriptType(golangType reflect.Type) (JavascriptType, error) { } } -func errWithTrace(prefix string, trace *list.List) error { - traceString := "root" - curr := trace.Front() - for curr != nil { - if curr.Value.(string) != "" { - traceString += " -> " + curr.Value.(string) - } - curr = curr.Next() - } - return fmt.Errorf("[ERROR] " + prefix + ". traversal trace: " + traceString) -} - // A wrapper over toSchema function to detect cycles in the bundle config struct -func safeToSchema(golangType reflect.Type, docs *Docs, debugTraceId string, seenTypes map[reflect.Type]struct{}, debugTrace *list.List) (*Schema, error) { +func safeToSchema(golangType reflect.Type, docs *Docs, traceId string, tracker *tracker) (*Schema, error) { // WE ERROR OUT IF THERE ARE CYCLES IN THE JSON SCHEMA // There are mechanisms to deal with cycles though recursive identifiers in json // schema. However if we use them, we would need to make sure we are able to detect @@ -127,24 +114,16 @@ func safeToSchema(golangType reflect.Type, docs *Docs, debugTraceId string, seen // // see: https://json-schema.org/understanding-json-schema/structuring.html#recursion // for details - _, ok := seenTypes[golangType] - if ok { - fmt.Println("[DEBUG] traceSet: ", seenTypes) + if tracker.hasCycle(golangType) { return nil, fmt.Errorf("cycle detected") } - // Update set of types in current path - seenTypes[golangType] = struct{}{} - - // Add the json tag name of struct field to debug trace - debugTrace.PushBack(debugTraceId) - props, err := toSchema(golangType, docs, seenTypes, debugTrace) + tracker.step(golangType, traceId) + props, err := toSchema(golangType, docs, tracker) if err != nil { return nil, err } - back := debugTrace.Back() - debugTrace.Remove(back) - delete(seenTypes, golangType) + tracker.undoStep(golangType) return props, nil } @@ -188,10 +167,10 @@ func getStructFields(golangType reflect.Type) []reflect.StructField { // Used to identify cycles. // debugTrace: linked list of golang types encounted. In case of errors this // helps log where the error originated from -func toSchema(golangType reflect.Type, docs *Docs, seenTypes map[reflect.Type]struct{}, debugTrace *list.List) (*Schema, error) { +func toSchema(golangType reflect.Type, docs *Docs, tracker *tracker) (*Schema, error) { // *Struct and Struct generate identical json schemas if golangType.Kind() == reflect.Pointer { - return safeToSchema(golangType.Elem(), docs, "", seenTypes, debugTrace) + return safeToSchema(golangType.Elem(), docs, "", tracker) } if golangType.Kind() == reflect.Interface { @@ -216,7 +195,7 @@ func toSchema(golangType reflect.Type, docs *Docs, seenTypes map[reflect.Type]st if err != nil { return nil, err } - elemProps, err := safeToSchema(elemGolangType, docs, "", seenTypes, debugTrace) + elemProps, err := safeToSchema(elemGolangType, docs, "", tracker) if err != nil { return nil, err } @@ -235,7 +214,7 @@ func toSchema(golangType reflect.Type, docs *Docs, seenTypes map[reflect.Type]st if golangType.Key().Kind() != reflect.String { return nil, fmt.Errorf("only string keyed maps allowed") } - additionalProperties, err = safeToSchema(golangType.Elem(), docs, "", seenTypes, debugTrace) + additionalProperties, err = safeToSchema(golangType.Elem(), docs, "", tracker) if err != nil { return nil, err } @@ -278,7 +257,7 @@ func toSchema(golangType reflect.Type, docs *Docs, seenTypes map[reflect.Type]st } // compute Schema.Properties for the child recursively - fieldProps, err := safeToSchema(child.Type, childDocs, childName, seenTypes, debugTrace) + fieldProps, err := safeToSchema(child.Type, childDocs, childName, tracker) if err != nil { return nil, err } diff --git a/bundle/schema/schema_test.go b/bundle/schema/schema_test.go index c512c72b8..a5161a3f2 100644 --- a/bundle/schema/schema_test.go +++ b/bundle/schema/schema_test.go @@ -1,7 +1,6 @@ package schema import ( - "container/list" "encoding/json" "reflect" "testing" @@ -661,17 +660,18 @@ func TestEmbeddedStructSchema(t *testing.T) { } func TestErrorWithTrace(t *testing.T) { - debugTrace := list.New() - err := errWithTrace("with empty trace", debugTrace) + tracker := newTracker() + dummyType := reflect.TypeOf(struct{}{}) + err := tracker.errWithTrace("with empty trace") assert.ErrorContains(t, err, "[ERROR] with empty trace. traversal trace: root") - debugTrace.PushBack("resources") - err = errWithTrace("with depth = 1", debugTrace) + tracker.step(dummyType, "resources") + err = tracker.errWithTrace("with depth = 1") assert.ErrorContains(t, err, "[ERROR] with depth = 1. traversal trace: root -> resources") - debugTrace.PushBack("pipelines") - debugTrace.PushBack("datasets") - err = errWithTrace("with depth = 4", debugTrace) + tracker.step(dummyType, "pipelines") + tracker.step(dummyType, "datasets") + err = tracker.errWithTrace("with depth = 4") assert.ErrorContains(t, err, "[ERROR] with depth = 4. traversal trace: root -> resources -> pipelines -> datasets") } diff --git a/bundle/schema/tracker.go b/bundle/schema/tracker.go new file mode 100644 index 000000000..a57a88d85 --- /dev/null +++ b/bundle/schema/tracker.go @@ -0,0 +1,55 @@ +package schema + +import ( + "container/list" + "fmt" + "reflect" +) + +type tracker struct { + // Types encountered in path of reaching the current type. Used to deletect + // cycles + seenTypes map[reflect.Type]struct{} + + // List of names from json tag encountered while reaching current type. This + // is logged on any error so we know on which type an error occured + debugTrace *list.List +} + +func newTracker() *tracker { + return &tracker{ + seenTypes: map[reflect.Type]struct{}{}, + debugTrace: list.New(), + } +} + +func (t *tracker) errWithTrace(prefix string) error { + traceString := "root" + curr := t.debugTrace.Front() + for curr != nil { + if curr.Value.(string) != "" { + traceString += " -> " + curr.Value.(string) + } + curr = curr.Next() + } + return fmt.Errorf("[ERROR] " + prefix + ". traversal trace: " + traceString) +} + +func (t *tracker) hasCycle(golangType reflect.Type) bool { + _, ok := t.seenTypes[golangType] + if !ok { + fmt.Println("[DEBUG] traceSet for cycle: ", t.seenTypes) + } + return ok +} + +func (t *tracker) step(nodeType reflect.Type, jsonName string) { + t.seenTypes[nodeType] = struct{}{} + t.debugTrace.PushBack(jsonName) +} + +func (t *tracker) undoStep(nodeType reflect.Type) { + back := t.debugTrace.Back() + t.debugTrace.Remove(back) + delete(t.seenTypes, nodeType) +}