Use tracker for reference loop tracking (#252)

We incorrectly relied on map key iteration order to print the debug trace.
Go does not specify map iteration order, so the printed trace was
nondeterministic. This PR switches over to using the tracker struct to allow
more reliable JSON schema reference loop detection and logging.

This also fixes the failing TestSelfReferenceLoopErrors and
TestCrossReferenceLoopErrors tests.
This commit is contained in:
shreyas-goenka 2023-03-16 12:57:57 +01:00 committed by GitHub
parent 207777849b
commit 7faa9dea9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 46 additions and 60 deletions

View File

@ -65,20 +65,19 @@ func (reader *OpenapiReader) readOpenapiSchema(path string) (*Schema, error) {
}
// safe against loops in refs
func (reader *OpenapiReader) safeResolveRefs(root *Schema, seenRefs map[string]struct{}) (*Schema, error) {
func (reader *OpenapiReader) safeResolveRefs(root *Schema, tracker *tracker) (*Schema, error) {
if root.Reference == nil {
return reader.traverseSchema(root, seenRefs)
return reader.traverseSchema(root, tracker)
}
key := *root.Reference
_, ok := seenRefs[key]
if ok {
if tracker.hasCycle(key) {
// self reference loops can be supported, however the logic is non-trivial because
// cross reference loops are not allowed (see: http://json-schema.org/understanding-json-schema/structuring.html#recursion)
return nil, fmt.Errorf("references loop detected")
}
ref := *root.Reference
description := root.Description
seenRefs[ref] = struct{}{}
tracker.push(ref, ref)
// Mark reference nil, so we do not traverse this again. This is tracked
// in the memo
@ -93,27 +92,27 @@ func (reader *OpenapiReader) safeResolveRefs(root *Schema, seenRefs map[string]s
root.Description = description
// traverse again to find new references
root, err = reader.traverseSchema(root, seenRefs)
root, err = reader.traverseSchema(root, tracker)
if err != nil {
return nil, err
}
delete(seenRefs, ref)
tracker.pop(ref)
return root, err
}
func (reader *OpenapiReader) traverseSchema(root *Schema, seenRefs map[string]struct{}) (*Schema, error) {
func (reader *OpenapiReader) traverseSchema(root *Schema, tracker *tracker) (*Schema, error) {
// case primitive (or invalid)
if root.Type != Object && root.Type != Array {
return root, nil
}
// only root references are resolved
if root.Reference != nil {
return reader.safeResolveRefs(root, seenRefs)
return reader.safeResolveRefs(root, tracker)
}
// case struct
if len(root.Properties) > 0 {
for k, v := range root.Properties {
childSchema, err := reader.safeResolveRefs(v, seenRefs)
childSchema, err := reader.safeResolveRefs(v, tracker)
if err != nil {
return nil, err
}
@ -122,7 +121,7 @@ func (reader *OpenapiReader) traverseSchema(root *Schema, seenRefs map[string]st
}
// case array
if root.Items != nil {
itemsSchema, err := reader.safeResolveRefs(root.Items, seenRefs)
itemsSchema, err := reader.safeResolveRefs(root.Items, tracker)
if err != nil {
return nil, err
}
@ -131,7 +130,7 @@ func (reader *OpenapiReader) traverseSchema(root *Schema, seenRefs map[string]st
// case map
additionionalProperties, ok := root.AdditionalProperties.(*Schema)
if ok && additionionalProperties != nil {
valueSchema, err := reader.safeResolveRefs(additionionalProperties, seenRefs)
valueSchema, err := reader.safeResolveRefs(additionionalProperties, tracker)
if err != nil {
return nil, err
}
@ -145,21 +144,11 @@ func (reader *OpenapiReader) readResolvedSchema(path string) (*Schema, error) {
if err != nil {
return nil, err
}
seenRefs := make(map[string]struct{})
seenRefs[path] = struct{}{}
root, err = reader.safeResolveRefs(root, seenRefs)
tracker := newTracker()
tracker.push(path, path)
root, err = reader.safeResolveRefs(root, tracker)
if err != nil {
trace := ""
count := 0
for k := range seenRefs {
if count == len(seenRefs)-1 {
trace += k
break
}
trace += k + " -> "
count++
}
return nil, fmt.Errorf("%s. schema ref trace: %s", err, trace)
return nil, tracker.errWithTrace(err.Error(), "")
}
return root, nil
}

View File

@ -226,7 +226,6 @@ func TestRootReferenceIsResolved(t *testing.T) {
}
func TestSelfReferenceLoopErrors(t *testing.T) {
t.Skip()
specString := `{
"components": {
"schemas": {
@ -257,11 +256,10 @@ func TestSelfReferenceLoopErrors(t *testing.T) {
require.NoError(t, err)
_, err = reader.readResolvedSchema("#/components/schemas/fruits")
assert.ErrorContains(t, err, "references loop detected. schema ref trace: #/components/schemas/fruits -> #/components/schemas/foo")
assert.ErrorContains(t, err, "references loop detected. traversal trace: -> #/components/schemas/fruits -> #/components/schemas/foo")
}
func TestCrossReferenceLoopErrors(t *testing.T) {
t.Skip()
specString := `{
"components": {
"schemas": {
@ -292,7 +290,7 @@ func TestCrossReferenceLoopErrors(t *testing.T) {
require.NoError(t, err)
_, err = reader.readResolvedSchema("#/components/schemas/fruits")
assert.ErrorContains(t, err, "references loop detected. schema ref trace: #/components/schemas/fruits -> #/components/schemas/foo")
assert.ErrorContains(t, err, "references loop detected. traversal trace: -> #/components/schemas/fruits -> #/components/schemas/foo")
}
func TestReferenceResolutionForMapInObject(t *testing.T) {

View File

@ -65,7 +65,7 @@ func New(golangType reflect.Type, docs *Docs) (*Schema, error) {
tracker := newTracker()
schema, err := safeToSchema(golangType, docs, "", tracker)
if err != nil {
return nil, tracker.errWithTrace(err.Error())
return nil, tracker.errWithTrace(err.Error(), "root")
}
return schema, nil
}

View File

@ -662,17 +662,17 @@ func TestEmbeddedStructSchema(t *testing.T) {
func TestErrorWithTrace(t *testing.T) {
tracker := newTracker()
dummyType := reflect.TypeOf(struct{}{})
err := tracker.errWithTrace("with empty trace")
assert.ErrorContains(t, err, "[ERROR] with empty trace. traversal trace: root")
err := tracker.errWithTrace("with empty trace", "root")
assert.ErrorContains(t, err, "with empty trace. traversal trace: root")
tracker.push(dummyType, "resources")
err = tracker.errWithTrace("with depth = 1")
assert.ErrorContains(t, err, "[ERROR] with depth = 1. traversal trace: root -> resources")
err = tracker.errWithTrace("with depth = 1", "root")
assert.ErrorContains(t, err, "with depth = 1. traversal trace: root -> resources")
tracker.push(dummyType, "pipelines")
tracker.push(dummyType, "datasets")
err = tracker.errWithTrace("with depth = 4")
assert.ErrorContains(t, err, "[ERROR] with depth = 4. traversal trace: root -> resources -> pipelines -> datasets")
err = tracker.errWithTrace("with depth = 4", "root")
assert.ErrorContains(t, err, "with depth = 4. traversal trace: root -> resources -> pipelines -> datasets")
}
func TestNonAnnotatedFieldsAreSkipped(t *testing.T) {
@ -1356,7 +1356,7 @@ func TestErrorIfStructRefersToItself(t *testing.T) {
elem := Foo{}
_, err := New(reflect.TypeOf(elem), nil)
assert.ErrorContains(t, err, "ERROR] cycle detected. traversal trace: root -> my_foo")
assert.ErrorContains(t, err, "cycle detected. traversal trace: root -> my_foo")
}
func TestErrorIfStructHasLoop(t *testing.T) {
@ -1373,7 +1373,7 @@ func TestErrorIfStructHasLoop(t *testing.T) {
elem := Apple{}
_, err := New(reflect.TypeOf(elem), nil)
assert.ErrorContains(t, err, "[ERROR] cycle detected. traversal trace: root -> my_mango -> my_guava -> my_papaya -> my_apple")
assert.ErrorContains(t, err, "cycle detected. traversal trace: root -> my_mango -> my_guava -> my_papaya -> my_apple")
}
func TestInterfaceGeneratesEmptySchema(t *testing.T) {

View File

@ -3,52 +3,51 @@ package schema
import (
"container/list"
"fmt"
"reflect"
)
type tracker struct {
// Types encountered in current path during the recursive traversal. Used to
// Nodes encountered in current path during the recursive traversal. Used to
// check for cycles
seenTypes map[reflect.Type]struct{}
seenNodes map[interface{}]struct{}
// List of field names encountered in current path during the recursive traversal.
// List of node names encountered in order in current path during the recursive traversal.
// Used to hydrate errors with path to the exact node where error occurred.
//
// The field names here are the first tag in the json tags of struct field.
debugTrace *list.List
// NOTE: node and node names can be the same
listOfNodes *list.List
}
func newTracker() *tracker {
return &tracker{
seenTypes: map[reflect.Type]struct{}{},
debugTrace: list.New(),
seenNodes: map[interface{}]struct{}{},
listOfNodes: list.New(),
}
}
func (t *tracker) errWithTrace(prefix string) error {
traceString := "root"
curr := t.debugTrace.Front()
func (t *tracker) errWithTrace(prefix string, initTrace string) error {
traceString := initTrace
curr := t.listOfNodes.Front()
for curr != nil {
if curr.Value.(string) != "" {
traceString += " -> " + curr.Value.(string)
}
curr = curr.Next()
}
return fmt.Errorf("[ERROR] " + prefix + ". traversal trace: " + traceString)
return fmt.Errorf(prefix + ". traversal trace: " + traceString)
}
func (t *tracker) hasCycle(golangType reflect.Type) bool {
_, ok := t.seenTypes[golangType]
func (t *tracker) hasCycle(node interface{}) bool {
_, ok := t.seenNodes[node]
return ok
}
func (t *tracker) push(nodeType reflect.Type, jsonName string) {
t.seenTypes[nodeType] = struct{}{}
t.debugTrace.PushBack(jsonName)
func (t *tracker) push(node interface{}, name string) {
t.seenNodes[node] = struct{}{}
t.listOfNodes.PushBack(name)
}
func (t *tracker) pop(nodeType reflect.Type) {
back := t.debugTrace.Back()
t.debugTrace.Remove(back)
delete(t.seenTypes, nodeType)
func (t *tracker) pop(nodeType interface{}) {
back := t.listOfNodes.Back()
t.listOfNodes.Remove(back)
delete(t.seenNodes, nodeType)
}