Use tracker for reference loop tracking (#252)

We incorrectly relied on map key iteration order to print the debug trace.
Go does not guarantee map iteration order, so the printed trace was
nondeterministic. This PR switches over to using the tracker struct to allow
more reliable JSON schema reference loop detection and logging.

This also fixes the failing TestSelfReferenceLoopErrors and
TestCrossReferenceLoopErrors tests, which were previously skipped.
This commit is contained in:
shreyas-goenka 2023-03-16 12:57:57 +01:00 committed by GitHub
parent 207777849b
commit 7faa9dea9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 46 additions and 60 deletions

View File

@ -65,20 +65,19 @@ func (reader *OpenapiReader) readOpenapiSchema(path string) (*Schema, error) {
} }
// safe againt loops in refs // safe againt loops in refs
func (reader *OpenapiReader) safeResolveRefs(root *Schema, seenRefs map[string]struct{}) (*Schema, error) { func (reader *OpenapiReader) safeResolveRefs(root *Schema, tracker *tracker) (*Schema, error) {
if root.Reference == nil { if root.Reference == nil {
return reader.traverseSchema(root, seenRefs) return reader.traverseSchema(root, tracker)
} }
key := *root.Reference key := *root.Reference
_, ok := seenRefs[key] if tracker.hasCycle(key) {
if ok {
// self reference loops can be supported however the logic is non-trivial because // self reference loops can be supported however the logic is non-trivial because
// cross refernce loops are not allowed (see: http://json-schema.org/understanding-json-schema/structuring.html#recursion) // cross refernce loops are not allowed (see: http://json-schema.org/understanding-json-schema/structuring.html#recursion)
return nil, fmt.Errorf("references loop detected") return nil, fmt.Errorf("references loop detected")
} }
ref := *root.Reference ref := *root.Reference
description := root.Description description := root.Description
seenRefs[ref] = struct{}{} tracker.push(ref, ref)
// Mark reference nil, so we do not traverse this again. This is tracked // Mark reference nil, so we do not traverse this again. This is tracked
// in the memo // in the memo
@ -93,27 +92,27 @@ func (reader *OpenapiReader) safeResolveRefs(root *Schema, seenRefs map[string]s
root.Description = description root.Description = description
// traverse again to find new references // traverse again to find new references
root, err = reader.traverseSchema(root, seenRefs) root, err = reader.traverseSchema(root, tracker)
if err != nil { if err != nil {
return nil, err return nil, err
} }
delete(seenRefs, ref) tracker.pop(ref)
return root, err return root, err
} }
func (reader *OpenapiReader) traverseSchema(root *Schema, seenRefs map[string]struct{}) (*Schema, error) { func (reader *OpenapiReader) traverseSchema(root *Schema, tracker *tracker) (*Schema, error) {
// case primitive (or invalid) // case primitive (or invalid)
if root.Type != Object && root.Type != Array { if root.Type != Object && root.Type != Array {
return root, nil return root, nil
} }
// only root references are resolved // only root references are resolved
if root.Reference != nil { if root.Reference != nil {
return reader.safeResolveRefs(root, seenRefs) return reader.safeResolveRefs(root, tracker)
} }
// case struct // case struct
if len(root.Properties) > 0 { if len(root.Properties) > 0 {
for k, v := range root.Properties { for k, v := range root.Properties {
childSchema, err := reader.safeResolveRefs(v, seenRefs) childSchema, err := reader.safeResolveRefs(v, tracker)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -122,7 +121,7 @@ func (reader *OpenapiReader) traverseSchema(root *Schema, seenRefs map[string]st
} }
// case array // case array
if root.Items != nil { if root.Items != nil {
itemsSchema, err := reader.safeResolveRefs(root.Items, seenRefs) itemsSchema, err := reader.safeResolveRefs(root.Items, tracker)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -131,7 +130,7 @@ func (reader *OpenapiReader) traverseSchema(root *Schema, seenRefs map[string]st
// case map // case map
additionionalProperties, ok := root.AdditionalProperties.(*Schema) additionionalProperties, ok := root.AdditionalProperties.(*Schema)
if ok && additionionalProperties != nil { if ok && additionionalProperties != nil {
valueSchema, err := reader.safeResolveRefs(additionionalProperties, seenRefs) valueSchema, err := reader.safeResolveRefs(additionionalProperties, tracker)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -145,21 +144,11 @@ func (reader *OpenapiReader) readResolvedSchema(path string) (*Schema, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
seenRefs := make(map[string]struct{}) tracker := newTracker()
seenRefs[path] = struct{}{} tracker.push(path, path)
root, err = reader.safeResolveRefs(root, seenRefs) root, err = reader.safeResolveRefs(root, tracker)
if err != nil { if err != nil {
trace := "" return nil, tracker.errWithTrace(err.Error(), "")
count := 0
for k := range seenRefs {
if count == len(seenRefs)-1 {
trace += k
break
}
trace += k + " -> "
count++
}
return nil, fmt.Errorf("%s. schema ref trace: %s", err, trace)
} }
return root, nil return root, nil
} }

View File

@ -226,7 +226,6 @@ func TestRootReferenceIsResolved(t *testing.T) {
} }
func TestSelfReferenceLoopErrors(t *testing.T) { func TestSelfReferenceLoopErrors(t *testing.T) {
t.Skip()
specString := `{ specString := `{
"components": { "components": {
"schemas": { "schemas": {
@ -257,11 +256,10 @@ func TestSelfReferenceLoopErrors(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
_, err = reader.readResolvedSchema("#/components/schemas/fruits") _, err = reader.readResolvedSchema("#/components/schemas/fruits")
assert.ErrorContains(t, err, "references loop detected. schema ref trace: #/components/schemas/fruits -> #/components/schemas/foo") assert.ErrorContains(t, err, "references loop detected. traversal trace: -> #/components/schemas/fruits -> #/components/schemas/foo")
} }
func TestCrossReferenceLoopErrors(t *testing.T) { func TestCrossReferenceLoopErrors(t *testing.T) {
t.Skip()
specString := `{ specString := `{
"components": { "components": {
"schemas": { "schemas": {
@ -292,7 +290,7 @@ func TestCrossReferenceLoopErrors(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
_, err = reader.readResolvedSchema("#/components/schemas/fruits") _, err = reader.readResolvedSchema("#/components/schemas/fruits")
assert.ErrorContains(t, err, "references loop detected. schema ref trace: #/components/schemas/fruits -> #/components/schemas/foo") assert.ErrorContains(t, err, "references loop detected. traversal trace: -> #/components/schemas/fruits -> #/components/schemas/foo")
} }
func TestReferenceResolutionForMapInObject(t *testing.T) { func TestReferenceResolutionForMapInObject(t *testing.T) {

View File

@ -65,7 +65,7 @@ func New(golangType reflect.Type, docs *Docs) (*Schema, error) {
tracker := newTracker() tracker := newTracker()
schema, err := safeToSchema(golangType, docs, "", tracker) schema, err := safeToSchema(golangType, docs, "", tracker)
if err != nil { if err != nil {
return nil, tracker.errWithTrace(err.Error()) return nil, tracker.errWithTrace(err.Error(), "root")
} }
return schema, nil return schema, nil
} }

View File

@ -662,17 +662,17 @@ func TestEmbeddedStructSchema(t *testing.T) {
func TestErrorWithTrace(t *testing.T) { func TestErrorWithTrace(t *testing.T) {
tracker := newTracker() tracker := newTracker()
dummyType := reflect.TypeOf(struct{}{}) dummyType := reflect.TypeOf(struct{}{})
err := tracker.errWithTrace("with empty trace") err := tracker.errWithTrace("with empty trace", "root")
assert.ErrorContains(t, err, "[ERROR] with empty trace. traversal trace: root") assert.ErrorContains(t, err, "with empty trace. traversal trace: root")
tracker.push(dummyType, "resources") tracker.push(dummyType, "resources")
err = tracker.errWithTrace("with depth = 1") err = tracker.errWithTrace("with depth = 1", "root")
assert.ErrorContains(t, err, "[ERROR] with depth = 1. traversal trace: root -> resources") assert.ErrorContains(t, err, "with depth = 1. traversal trace: root -> resources")
tracker.push(dummyType, "pipelines") tracker.push(dummyType, "pipelines")
tracker.push(dummyType, "datasets") tracker.push(dummyType, "datasets")
err = tracker.errWithTrace("with depth = 4") err = tracker.errWithTrace("with depth = 4", "root")
assert.ErrorContains(t, err, "[ERROR] with depth = 4. traversal trace: root -> resources -> pipelines -> datasets") assert.ErrorContains(t, err, "with depth = 4. traversal trace: root -> resources -> pipelines -> datasets")
} }
func TestNonAnnotatedFieldsAreSkipped(t *testing.T) { func TestNonAnnotatedFieldsAreSkipped(t *testing.T) {
@ -1356,7 +1356,7 @@ func TestErrorIfStructRefersToItself(t *testing.T) {
elem := Foo{} elem := Foo{}
_, err := New(reflect.TypeOf(elem), nil) _, err := New(reflect.TypeOf(elem), nil)
assert.ErrorContains(t, err, "ERROR] cycle detected. traversal trace: root -> my_foo") assert.ErrorContains(t, err, "cycle detected. traversal trace: root -> my_foo")
} }
func TestErrorIfStructHasLoop(t *testing.T) { func TestErrorIfStructHasLoop(t *testing.T) {
@ -1373,7 +1373,7 @@ func TestErrorIfStructHasLoop(t *testing.T) {
elem := Apple{} elem := Apple{}
_, err := New(reflect.TypeOf(elem), nil) _, err := New(reflect.TypeOf(elem), nil)
assert.ErrorContains(t, err, "[ERROR] cycle detected. traversal trace: root -> my_mango -> my_guava -> my_papaya -> my_apple") assert.ErrorContains(t, err, "cycle detected. traversal trace: root -> my_mango -> my_guava -> my_papaya -> my_apple")
} }
func TestInterfaceGeneratesEmptySchema(t *testing.T) { func TestInterfaceGeneratesEmptySchema(t *testing.T) {

View File

@ -3,52 +3,51 @@ package schema
import ( import (
"container/list" "container/list"
"fmt" "fmt"
"reflect"
) )
type tracker struct { type tracker struct {
// Types encountered in current path during the recursive traversal. Used to // Nodes encountered in current path during the recursive traversal. Used to
// check for cycles // check for cycles
seenTypes map[reflect.Type]struct{} seenNodes map[interface{}]struct{}
// List of field names encountered in current path during the recursive traversal. // List of node names encountered in order in current path during the recursive traversal.
// Used to hydrate errors with path to the exact node where error occured. // Used to hydrate errors with path to the exact node where error occured.
// //
// The field names here are the first tag in the json tags of struct field. // NOTE: node and node names can be the same
debugTrace *list.List listOfNodes *list.List
} }
func newTracker() *tracker { func newTracker() *tracker {
return &tracker{ return &tracker{
seenTypes: map[reflect.Type]struct{}{}, seenNodes: map[interface{}]struct{}{},
debugTrace: list.New(), listOfNodes: list.New(),
} }
} }
func (t *tracker) errWithTrace(prefix string) error { func (t *tracker) errWithTrace(prefix string, initTrace string) error {
traceString := "root" traceString := initTrace
curr := t.debugTrace.Front() curr := t.listOfNodes.Front()
for curr != nil { for curr != nil {
if curr.Value.(string) != "" { if curr.Value.(string) != "" {
traceString += " -> " + curr.Value.(string) traceString += " -> " + curr.Value.(string)
} }
curr = curr.Next() curr = curr.Next()
} }
return fmt.Errorf("[ERROR] " + prefix + ". traversal trace: " + traceString) return fmt.Errorf(prefix + ". traversal trace: " + traceString)
} }
func (t *tracker) hasCycle(golangType reflect.Type) bool { func (t *tracker) hasCycle(node interface{}) bool {
_, ok := t.seenTypes[golangType] _, ok := t.seenNodes[node]
return ok return ok
} }
func (t *tracker) push(nodeType reflect.Type, jsonName string) { func (t *tracker) push(node interface{}, name string) {
t.seenTypes[nodeType] = struct{}{} t.seenNodes[node] = struct{}{}
t.debugTrace.PushBack(jsonName) t.listOfNodes.PushBack(name)
} }
func (t *tracker) pop(nodeType reflect.Type) { func (t *tracker) pop(nodeType interface{}) {
back := t.debugTrace.Back() back := t.listOfNodes.Back()
t.debugTrace.Remove(back) t.listOfNodes.Remove(back)
delete(t.seenTypes, nodeType) delete(t.seenNodes, nodeType)
} }