mirror of https://github.com/databricks/cli.git
Use tracker for reference loop tracking (#252)
We incorrectly relied on map key iteration order to print debug trace. This PR switches over to using the tracker struct to allow more reliable json schema reference loop detection and logging This also fixes the failing TestSelfReferenceLoopErrors and TestCrossReferenceLoopErrors tests
This commit is contained in:
parent
207777849b
commit
7faa9dea9b
|
@ -65,20 +65,19 @@ func (reader *OpenapiReader) readOpenapiSchema(path string) (*Schema, error) {
|
|||
}
|
||||
|
||||
// safe againt loops in refs
|
||||
func (reader *OpenapiReader) safeResolveRefs(root *Schema, seenRefs map[string]struct{}) (*Schema, error) {
|
||||
func (reader *OpenapiReader) safeResolveRefs(root *Schema, tracker *tracker) (*Schema, error) {
|
||||
if root.Reference == nil {
|
||||
return reader.traverseSchema(root, seenRefs)
|
||||
return reader.traverseSchema(root, tracker)
|
||||
}
|
||||
key := *root.Reference
|
||||
_, ok := seenRefs[key]
|
||||
if ok {
|
||||
if tracker.hasCycle(key) {
|
||||
// self reference loops can be supported however the logic is non-trivial because
|
||||
// cross refernce loops are not allowed (see: http://json-schema.org/understanding-json-schema/structuring.html#recursion)
|
||||
return nil, fmt.Errorf("references loop detected")
|
||||
}
|
||||
ref := *root.Reference
|
||||
description := root.Description
|
||||
seenRefs[ref] = struct{}{}
|
||||
tracker.push(ref, ref)
|
||||
|
||||
// Mark reference nil, so we do not traverse this again. This is tracked
|
||||
// in the memo
|
||||
|
@ -93,27 +92,27 @@ func (reader *OpenapiReader) safeResolveRefs(root *Schema, seenRefs map[string]s
|
|||
root.Description = description
|
||||
|
||||
// traverse again to find new references
|
||||
root, err = reader.traverseSchema(root, seenRefs)
|
||||
root, err = reader.traverseSchema(root, tracker)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
delete(seenRefs, ref)
|
||||
tracker.pop(ref)
|
||||
return root, err
|
||||
}
|
||||
|
||||
func (reader *OpenapiReader) traverseSchema(root *Schema, seenRefs map[string]struct{}) (*Schema, error) {
|
||||
func (reader *OpenapiReader) traverseSchema(root *Schema, tracker *tracker) (*Schema, error) {
|
||||
// case primitive (or invalid)
|
||||
if root.Type != Object && root.Type != Array {
|
||||
return root, nil
|
||||
}
|
||||
// only root references are resolved
|
||||
if root.Reference != nil {
|
||||
return reader.safeResolveRefs(root, seenRefs)
|
||||
return reader.safeResolveRefs(root, tracker)
|
||||
}
|
||||
// case struct
|
||||
if len(root.Properties) > 0 {
|
||||
for k, v := range root.Properties {
|
||||
childSchema, err := reader.safeResolveRefs(v, seenRefs)
|
||||
childSchema, err := reader.safeResolveRefs(v, tracker)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -122,7 +121,7 @@ func (reader *OpenapiReader) traverseSchema(root *Schema, seenRefs map[string]st
|
|||
}
|
||||
// case array
|
||||
if root.Items != nil {
|
||||
itemsSchema, err := reader.safeResolveRefs(root.Items, seenRefs)
|
||||
itemsSchema, err := reader.safeResolveRefs(root.Items, tracker)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -131,7 +130,7 @@ func (reader *OpenapiReader) traverseSchema(root *Schema, seenRefs map[string]st
|
|||
// case map
|
||||
additionionalProperties, ok := root.AdditionalProperties.(*Schema)
|
||||
if ok && additionionalProperties != nil {
|
||||
valueSchema, err := reader.safeResolveRefs(additionionalProperties, seenRefs)
|
||||
valueSchema, err := reader.safeResolveRefs(additionionalProperties, tracker)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -145,21 +144,11 @@ func (reader *OpenapiReader) readResolvedSchema(path string) (*Schema, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
seenRefs := make(map[string]struct{})
|
||||
seenRefs[path] = struct{}{}
|
||||
root, err = reader.safeResolveRefs(root, seenRefs)
|
||||
tracker := newTracker()
|
||||
tracker.push(path, path)
|
||||
root, err = reader.safeResolveRefs(root, tracker)
|
||||
if err != nil {
|
||||
trace := ""
|
||||
count := 0
|
||||
for k := range seenRefs {
|
||||
if count == len(seenRefs)-1 {
|
||||
trace += k
|
||||
break
|
||||
}
|
||||
trace += k + " -> "
|
||||
count++
|
||||
}
|
||||
return nil, fmt.Errorf("%s. schema ref trace: %s", err, trace)
|
||||
return nil, tracker.errWithTrace(err.Error(), "")
|
||||
}
|
||||
return root, nil
|
||||
}
|
||||
|
|
|
@ -226,7 +226,6 @@ func TestRootReferenceIsResolved(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSelfReferenceLoopErrors(t *testing.T) {
|
||||
t.Skip()
|
||||
specString := `{
|
||||
"components": {
|
||||
"schemas": {
|
||||
|
@ -257,11 +256,10 @@ func TestSelfReferenceLoopErrors(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
_, err = reader.readResolvedSchema("#/components/schemas/fruits")
|
||||
assert.ErrorContains(t, err, "references loop detected. schema ref trace: #/components/schemas/fruits -> #/components/schemas/foo")
|
||||
assert.ErrorContains(t, err, "references loop detected. traversal trace: -> #/components/schemas/fruits -> #/components/schemas/foo")
|
||||
}
|
||||
|
||||
func TestCrossReferenceLoopErrors(t *testing.T) {
|
||||
t.Skip()
|
||||
specString := `{
|
||||
"components": {
|
||||
"schemas": {
|
||||
|
@ -292,7 +290,7 @@ func TestCrossReferenceLoopErrors(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
_, err = reader.readResolvedSchema("#/components/schemas/fruits")
|
||||
assert.ErrorContains(t, err, "references loop detected. schema ref trace: #/components/schemas/fruits -> #/components/schemas/foo")
|
||||
assert.ErrorContains(t, err, "references loop detected. traversal trace: -> #/components/schemas/fruits -> #/components/schemas/foo")
|
||||
}
|
||||
|
||||
func TestReferenceResolutionForMapInObject(t *testing.T) {
|
||||
|
|
|
@ -65,7 +65,7 @@ func New(golangType reflect.Type, docs *Docs) (*Schema, error) {
|
|||
tracker := newTracker()
|
||||
schema, err := safeToSchema(golangType, docs, "", tracker)
|
||||
if err != nil {
|
||||
return nil, tracker.errWithTrace(err.Error())
|
||||
return nil, tracker.errWithTrace(err.Error(), "root")
|
||||
}
|
||||
return schema, nil
|
||||
}
|
||||
|
|
|
@ -662,17 +662,17 @@ func TestEmbeddedStructSchema(t *testing.T) {
|
|||
func TestErrorWithTrace(t *testing.T) {
|
||||
tracker := newTracker()
|
||||
dummyType := reflect.TypeOf(struct{}{})
|
||||
err := tracker.errWithTrace("with empty trace")
|
||||
assert.ErrorContains(t, err, "[ERROR] with empty trace. traversal trace: root")
|
||||
err := tracker.errWithTrace("with empty trace", "root")
|
||||
assert.ErrorContains(t, err, "with empty trace. traversal trace: root")
|
||||
|
||||
tracker.push(dummyType, "resources")
|
||||
err = tracker.errWithTrace("with depth = 1")
|
||||
assert.ErrorContains(t, err, "[ERROR] with depth = 1. traversal trace: root -> resources")
|
||||
err = tracker.errWithTrace("with depth = 1", "root")
|
||||
assert.ErrorContains(t, err, "with depth = 1. traversal trace: root -> resources")
|
||||
|
||||
tracker.push(dummyType, "pipelines")
|
||||
tracker.push(dummyType, "datasets")
|
||||
err = tracker.errWithTrace("with depth = 4")
|
||||
assert.ErrorContains(t, err, "[ERROR] with depth = 4. traversal trace: root -> resources -> pipelines -> datasets")
|
||||
err = tracker.errWithTrace("with depth = 4", "root")
|
||||
assert.ErrorContains(t, err, "with depth = 4. traversal trace: root -> resources -> pipelines -> datasets")
|
||||
}
|
||||
|
||||
func TestNonAnnotatedFieldsAreSkipped(t *testing.T) {
|
||||
|
@ -1356,7 +1356,7 @@ func TestErrorIfStructRefersToItself(t *testing.T) {
|
|||
|
||||
elem := Foo{}
|
||||
_, err := New(reflect.TypeOf(elem), nil)
|
||||
assert.ErrorContains(t, err, "ERROR] cycle detected. traversal trace: root -> my_foo")
|
||||
assert.ErrorContains(t, err, "cycle detected. traversal trace: root -> my_foo")
|
||||
}
|
||||
|
||||
func TestErrorIfStructHasLoop(t *testing.T) {
|
||||
|
@ -1373,7 +1373,7 @@ func TestErrorIfStructHasLoop(t *testing.T) {
|
|||
|
||||
elem := Apple{}
|
||||
_, err := New(reflect.TypeOf(elem), nil)
|
||||
assert.ErrorContains(t, err, "[ERROR] cycle detected. traversal trace: root -> my_mango -> my_guava -> my_papaya -> my_apple")
|
||||
assert.ErrorContains(t, err, "cycle detected. traversal trace: root -> my_mango -> my_guava -> my_papaya -> my_apple")
|
||||
}
|
||||
|
||||
func TestInterfaceGeneratesEmptySchema(t *testing.T) {
|
||||
|
|
|
@ -3,52 +3,51 @@ package schema
|
|||
import (
|
||||
"container/list"
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type tracker struct {
|
||||
// Types encountered in current path during the recursive traversal. Used to
|
||||
// Nodes encountered in current path during the recursive traversal. Used to
|
||||
// check for cycles
|
||||
seenTypes map[reflect.Type]struct{}
|
||||
seenNodes map[interface{}]struct{}
|
||||
|
||||
// List of field names encountered in current path during the recursive traversal.
|
||||
// List of node names encountered in order in current path during the recursive traversal.
|
||||
// Used to hydrate errors with path to the exact node where error occured.
|
||||
//
|
||||
// The field names here are the first tag in the json tags of struct field.
|
||||
debugTrace *list.List
|
||||
// NOTE: node and node names can be the same
|
||||
listOfNodes *list.List
|
||||
}
|
||||
|
||||
func newTracker() *tracker {
|
||||
return &tracker{
|
||||
seenTypes: map[reflect.Type]struct{}{},
|
||||
debugTrace: list.New(),
|
||||
seenNodes: map[interface{}]struct{}{},
|
||||
listOfNodes: list.New(),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tracker) errWithTrace(prefix string) error {
|
||||
traceString := "root"
|
||||
curr := t.debugTrace.Front()
|
||||
func (t *tracker) errWithTrace(prefix string, initTrace string) error {
|
||||
traceString := initTrace
|
||||
curr := t.listOfNodes.Front()
|
||||
for curr != nil {
|
||||
if curr.Value.(string) != "" {
|
||||
traceString += " -> " + curr.Value.(string)
|
||||
}
|
||||
curr = curr.Next()
|
||||
}
|
||||
return fmt.Errorf("[ERROR] " + prefix + ". traversal trace: " + traceString)
|
||||
return fmt.Errorf(prefix + ". traversal trace: " + traceString)
|
||||
}
|
||||
|
||||
func (t *tracker) hasCycle(golangType reflect.Type) bool {
|
||||
_, ok := t.seenTypes[golangType]
|
||||
func (t *tracker) hasCycle(node interface{}) bool {
|
||||
_, ok := t.seenNodes[node]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (t *tracker) push(nodeType reflect.Type, jsonName string) {
|
||||
t.seenTypes[nodeType] = struct{}{}
|
||||
t.debugTrace.PushBack(jsonName)
|
||||
func (t *tracker) push(node interface{}, name string) {
|
||||
t.seenNodes[node] = struct{}{}
|
||||
t.listOfNodes.PushBack(name)
|
||||
}
|
||||
|
||||
func (t *tracker) pop(nodeType reflect.Type) {
|
||||
back := t.debugTrace.Back()
|
||||
t.debugTrace.Remove(back)
|
||||
delete(t.seenTypes, nodeType)
|
||||
func (t *tracker) pop(nodeType interface{}) {
|
||||
back := t.listOfNodes.Back()
|
||||
t.listOfNodes.Remove(back)
|
||||
delete(t.seenNodes, nodeType)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue