diff --git a/bundle/schema/openapi.go b/bundle/schema/openapi.go index 97ec5141..32da7fdc 100644 --- a/bundle/schema/openapi.go +++ b/bundle/schema/openapi.go @@ -65,20 +65,19 @@ func (reader *OpenapiReader) readOpenapiSchema(path string) (*Schema, error) { } // safe againt loops in refs -func (reader *OpenapiReader) safeResolveRefs(root *Schema, seenRefs map[string]struct{}) (*Schema, error) { +func (reader *OpenapiReader) safeResolveRefs(root *Schema, tracker *tracker) (*Schema, error) { if root.Reference == nil { - return reader.traverseSchema(root, seenRefs) + return reader.traverseSchema(root, tracker) } key := *root.Reference - _, ok := seenRefs[key] - if ok { + if tracker.hasCycle(key) { // self reference loops can be supported however the logic is non-trivial because // cross refernce loops are not allowed (see: http://json-schema.org/understanding-json-schema/structuring.html#recursion) return nil, fmt.Errorf("references loop detected") } ref := *root.Reference description := root.Description - seenRefs[ref] = struct{}{} + tracker.push(ref, ref) // Mark reference nil, so we do not traverse this again. This is tracked // in the memo @@ -93,27 +92,27 @@ func (reader *OpenapiReader) safeResolveRefs(root *Schema, seenRefs map[string]s root.Description = description // traverse again to find new references - root, err = reader.traverseSchema(root, seenRefs) + root, err = reader.traverseSchema(root, tracker) if err != nil { return nil, err } - delete(seenRefs, ref) + tracker.pop(ref) return root, err } -func (reader *OpenapiReader) traverseSchema(root *Schema, seenRefs map[string]struct{}) (*Schema, error) { +func (reader *OpenapiReader) traverseSchema(root *Schema, tracker *tracker) (*Schema, error) { // case primitive (or invalid) if root.Type != Object && root.Type != Array { return root, nil } // only root references are resolved if root.Reference != nil { - return reader.safeResolveRefs(root, seenRefs) + return reader.safeResolveRefs(root, tracker) } // case struct if len(root.Properties) > 0 { for k, v := range root.Properties { - childSchema, err := reader.safeResolveRefs(v, seenRefs) + childSchema, err := reader.safeResolveRefs(v, tracker) if err != nil { return nil, err } @@ -122,7 +121,7 @@ func (reader *OpenapiReader) traverseSchema(root *Schema, seenRefs map[string]st } // case array if root.Items != nil { - itemsSchema, err := reader.safeResolveRefs(root.Items, seenRefs) + itemsSchema, err := reader.safeResolveRefs(root.Items, tracker) if err != nil { return nil, err } @@ -131,7 +130,7 @@ func (reader *OpenapiReader) traverseSchema(root *Schema, seenRefs map[string]st // case map additionionalProperties, ok := root.AdditionalProperties.(*Schema) if ok && additionionalProperties != nil { - valueSchema, err := reader.safeResolveRefs(additionionalProperties, seenRefs) + valueSchema, err := reader.safeResolveRefs(additionionalProperties, tracker) if err != nil { return nil, err } @@ -145,21 +144,11 @@ func (reader *OpenapiReader) readResolvedSchema(path string) (*Schema, error) { if err != nil { return nil, err } - seenRefs := make(map[string]struct{}) - seenRefs[path] = struct{}{} - root, err = reader.safeResolveRefs(root, seenRefs) + tracker := newTracker() + tracker.push(path, path) + root, err = reader.safeResolveRefs(root, tracker) if err != nil { - trace := "" - count := 0 - for k := range seenRefs { - if count == len(seenRefs)-1 { - trace += k - break - } - trace += k + " -> " - count++ - } - return nil, fmt.Errorf("%s. schema ref trace: %s", err, trace) + return nil, tracker.errWithTrace(err.Error(), "") } return root, nil } diff --git a/bundle/schema/openapi_test.go b/bundle/schema/openapi_test.go index 3e9d2533..282fac8d 100644 --- a/bundle/schema/openapi_test.go +++ b/bundle/schema/openapi_test.go @@ -226,7 +226,6 @@ func TestRootReferenceIsResolved(t *testing.T) { } func TestSelfReferenceLoopErrors(t *testing.T) { - t.Skip() specString := `{ "components": { "schemas": { @@ -257,11 +256,10 @@ func TestSelfReferenceLoopErrors(t *testing.T) { require.NoError(t, err) _, err = reader.readResolvedSchema("#/components/schemas/fruits") - assert.ErrorContains(t, err, "references loop detected. schema ref trace: #/components/schemas/fruits -> #/components/schemas/foo") + assert.ErrorContains(t, err, "references loop detected. traversal trace: -> #/components/schemas/fruits -> #/components/schemas/foo") } func TestCrossReferenceLoopErrors(t *testing.T) { - t.Skip() specString := `{ "components": { "schemas": { @@ -292,7 +290,7 @@ func TestCrossReferenceLoopErrors(t *testing.T) { require.NoError(t, err) _, err = reader.readResolvedSchema("#/components/schemas/fruits") - assert.ErrorContains(t, err, "references loop detected. schema ref trace: #/components/schemas/fruits -> #/components/schemas/foo") + assert.ErrorContains(t, err, "references loop detected. traversal trace: -> #/components/schemas/fruits -> #/components/schemas/foo") } func TestReferenceResolutionForMapInObject(t *testing.T) { diff --git a/bundle/schema/schema.go b/bundle/schema/schema.go index 2e50298c..af3dfd76 100644 --- a/bundle/schema/schema.go +++ b/bundle/schema/schema.go @@ -65,7 +65,7 @@ func New(golangType reflect.Type, docs *Docs) (*Schema, error) { tracker := newTracker() schema, err := safeToSchema(golangType, docs, "", tracker) if err != nil { - return nil, tracker.errWithTrace(err.Error()) + return nil, tracker.errWithTrace(err.Error(), "root") } return schema, nil } diff --git a/bundle/schema/schema_test.go b/bundle/schema/schema_test.go index 966aab81..edb56371 100644 --- a/bundle/schema/schema_test.go +++ b/bundle/schema/schema_test.go @@ -662,17 +662,17 @@ func TestEmbeddedStructSchema(t *testing.T) { func TestErrorWithTrace(t *testing.T) { tracker := newTracker() dummyType := reflect.TypeOf(struct{}{}) - err := tracker.errWithTrace("with empty trace") - assert.ErrorContains(t, err, "[ERROR] with empty trace. traversal trace: root") + err := tracker.errWithTrace("with empty trace", "root") + assert.ErrorContains(t, err, "with empty trace. traversal trace: root") tracker.push(dummyType, "resources") - err = tracker.errWithTrace("with depth = 1") - assert.ErrorContains(t, err, "[ERROR] with depth = 1. traversal trace: root -> resources") + err = tracker.errWithTrace("with depth = 1", "root") + assert.ErrorContains(t, err, "with depth = 1. traversal trace: root -> resources") tracker.push(dummyType, "pipelines") tracker.push(dummyType, "datasets") - err = tracker.errWithTrace("with depth = 4") - assert.ErrorContains(t, err, "[ERROR] with depth = 4. traversal trace: root -> resources -> pipelines -> datasets") + err = tracker.errWithTrace("with depth = 4", "root") + assert.ErrorContains(t, err, "with depth = 4. traversal trace: root -> resources -> pipelines -> datasets") } func TestNonAnnotatedFieldsAreSkipped(t *testing.T) { @@ -1356,7 +1356,7 @@ func TestErrorIfStructRefersToItself(t *testing.T) { elem := Foo{} _, err := New(reflect.TypeOf(elem), nil) - assert.ErrorContains(t, err, "ERROR] cycle detected. traversal trace: root -> my_foo") + assert.ErrorContains(t, err, "cycle detected. traversal trace: root -> my_foo") } func TestErrorIfStructHasLoop(t *testing.T) { @@ -1373,7 +1373,7 @@ func TestErrorIfStructHasLoop(t *testing.T) { elem := Apple{} _, err := New(reflect.TypeOf(elem), nil) - assert.ErrorContains(t, err, "[ERROR] cycle detected. traversal trace: root -> my_mango -> my_guava -> my_papaya -> my_apple") + assert.ErrorContains(t, err, "cycle detected. traversal trace: root -> my_mango -> my_guava -> my_papaya -> my_apple") } func TestInterfaceGeneratesEmptySchema(t *testing.T) { diff --git a/bundle/schema/tracker.go b/bundle/schema/tracker.go index 8dc4f3f4..ace6559b 100644 --- a/bundle/schema/tracker.go +++ b/bundle/schema/tracker.go @@ -3,52 +3,51 @@ package schema import ( "container/list" "fmt" - "reflect" ) type tracker struct { - // Types encountered in current path during the recursive traversal. Used to + // Nodes encountered in current path during the recursive traversal. Used to // check for cycles - seenTypes map[reflect.Type]struct{} + seenNodes map[interface{}]struct{} - // List of field names encountered in current path during the recursive traversal. + // List of node names encountered in order in current path during the recursive traversal. // Used to hydrate errors with path to the exact node where error occured. // - // The field names here are the first tag in the json tags of struct field. - debugTrace *list.List + // NOTE: node and node names can be the same + listOfNodes *list.List } func newTracker() *tracker { return &tracker{ - seenTypes: map[reflect.Type]struct{}{}, - debugTrace: list.New(), + seenNodes: map[interface{}]struct{}{}, + listOfNodes: list.New(), } } -func (t *tracker) errWithTrace(prefix string) error { - traceString := "root" - curr := t.debugTrace.Front() +func (t *tracker) errWithTrace(prefix string, initTrace string) error { + traceString := initTrace + curr := t.listOfNodes.Front() for curr != nil { if curr.Value.(string) != "" { traceString += " -> " + curr.Value.(string) } curr = curr.Next() } - return fmt.Errorf("[ERROR] " + prefix + ". traversal trace: " + traceString) + return fmt.Errorf(prefix + ". traversal trace: " + traceString) } -func (t *tracker) hasCycle(golangType reflect.Type) bool { - _, ok := t.seenTypes[golangType] +func (t *tracker) hasCycle(node interface{}) bool { + _, ok := t.seenNodes[node] return ok } -func (t *tracker) push(nodeType reflect.Type, jsonName string) { - t.seenTypes[nodeType] = struct{}{} - t.debugTrace.PushBack(jsonName) +func (t *tracker) push(node interface{}, name string) { + t.seenNodes[node] = struct{}{} + t.listOfNodes.PushBack(name) } -func (t *tracker) pop(nodeType reflect.Type) { - back := t.debugTrace.Back() - t.debugTrace.Remove(back) - delete(t.seenTypes, nodeType) +func (t *tracker) pop(nodeType interface{}) { + back := t.listOfNodes.Back() + t.listOfNodes.Remove(back) + delete(t.seenNodes, nodeType) }