mirror of https://github.com/databricks/cli.git
Use tracker for reference loop tracking (#252)
We incorrectly relied on map key iteration order to print the debug trace. This PR switches over to using the tracker struct to allow more reliable JSON schema reference loop detection and logging. This also fixes the failing TestSelfReferenceLoopErrors and TestCrossReferenceLoopErrors tests.
This commit is contained in:
parent
207777849b
commit
7faa9dea9b
|
@ -65,20 +65,19 @@ func (reader *OpenapiReader) readOpenapiSchema(path string) (*Schema, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// safe against loops in refs
|
// safe against loops in refs
|
||||||
func (reader *OpenapiReader) safeResolveRefs(root *Schema, seenRefs map[string]struct{}) (*Schema, error) {
|
func (reader *OpenapiReader) safeResolveRefs(root *Schema, tracker *tracker) (*Schema, error) {
|
||||||
if root.Reference == nil {
|
if root.Reference == nil {
|
||||||
return reader.traverseSchema(root, seenRefs)
|
return reader.traverseSchema(root, tracker)
|
||||||
}
|
}
|
||||||
key := *root.Reference
|
key := *root.Reference
|
||||||
_, ok := seenRefs[key]
|
if tracker.hasCycle(key) {
|
||||||
if ok {
|
|
||||||
// self reference loops can be supported however the logic is non-trivial because
|
// self reference loops can be supported however the logic is non-trivial because
|
||||||
// cross reference loops are not allowed (see: http://json-schema.org/understanding-json-schema/structuring.html#recursion)
|
// cross reference loops are not allowed (see: http://json-schema.org/understanding-json-schema/structuring.html#recursion)
|
||||||
return nil, fmt.Errorf("references loop detected")
|
return nil, fmt.Errorf("references loop detected")
|
||||||
}
|
}
|
||||||
ref := *root.Reference
|
ref := *root.Reference
|
||||||
description := root.Description
|
description := root.Description
|
||||||
seenRefs[ref] = struct{}{}
|
tracker.push(ref, ref)
|
||||||
|
|
||||||
// Mark reference nil, so we do not traverse this again. This is tracked
|
// Mark reference nil, so we do not traverse this again. This is tracked
|
||||||
// in the memo
|
// in the memo
|
||||||
|
@ -93,27 +92,27 @@ func (reader *OpenapiReader) safeResolveRefs(root *Schema, seenRefs map[string]s
|
||||||
root.Description = description
|
root.Description = description
|
||||||
|
|
||||||
// traverse again to find new references
|
// traverse again to find new references
|
||||||
root, err = reader.traverseSchema(root, seenRefs)
|
root, err = reader.traverseSchema(root, tracker)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
delete(seenRefs, ref)
|
tracker.pop(ref)
|
||||||
return root, err
|
return root, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (reader *OpenapiReader) traverseSchema(root *Schema, seenRefs map[string]struct{}) (*Schema, error) {
|
func (reader *OpenapiReader) traverseSchema(root *Schema, tracker *tracker) (*Schema, error) {
|
||||||
// case primitive (or invalid)
|
// case primitive (or invalid)
|
||||||
if root.Type != Object && root.Type != Array {
|
if root.Type != Object && root.Type != Array {
|
||||||
return root, nil
|
return root, nil
|
||||||
}
|
}
|
||||||
// only root references are resolved
|
// only root references are resolved
|
||||||
if root.Reference != nil {
|
if root.Reference != nil {
|
||||||
return reader.safeResolveRefs(root, seenRefs)
|
return reader.safeResolveRefs(root, tracker)
|
||||||
}
|
}
|
||||||
// case struct
|
// case struct
|
||||||
if len(root.Properties) > 0 {
|
if len(root.Properties) > 0 {
|
||||||
for k, v := range root.Properties {
|
for k, v := range root.Properties {
|
||||||
childSchema, err := reader.safeResolveRefs(v, seenRefs)
|
childSchema, err := reader.safeResolveRefs(v, tracker)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -122,7 +121,7 @@ func (reader *OpenapiReader) traverseSchema(root *Schema, seenRefs map[string]st
|
||||||
}
|
}
|
||||||
// case array
|
// case array
|
||||||
if root.Items != nil {
|
if root.Items != nil {
|
||||||
itemsSchema, err := reader.safeResolveRefs(root.Items, seenRefs)
|
itemsSchema, err := reader.safeResolveRefs(root.Items, tracker)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -131,7 +130,7 @@ func (reader *OpenapiReader) traverseSchema(root *Schema, seenRefs map[string]st
|
||||||
// case map
|
// case map
|
||||||
additionionalProperties, ok := root.AdditionalProperties.(*Schema)
|
additionionalProperties, ok := root.AdditionalProperties.(*Schema)
|
||||||
if ok && additionionalProperties != nil {
|
if ok && additionionalProperties != nil {
|
||||||
valueSchema, err := reader.safeResolveRefs(additionionalProperties, seenRefs)
|
valueSchema, err := reader.safeResolveRefs(additionionalProperties, tracker)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -145,21 +144,11 @@ func (reader *OpenapiReader) readResolvedSchema(path string) (*Schema, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
seenRefs := make(map[string]struct{})
|
tracker := newTracker()
|
||||||
seenRefs[path] = struct{}{}
|
tracker.push(path, path)
|
||||||
root, err = reader.safeResolveRefs(root, seenRefs)
|
root, err = reader.safeResolveRefs(root, tracker)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
trace := ""
|
return nil, tracker.errWithTrace(err.Error(), "")
|
||||||
count := 0
|
|
||||||
for k := range seenRefs {
|
|
||||||
if count == len(seenRefs)-1 {
|
|
||||||
trace += k
|
|
||||||
break
|
|
||||||
}
|
|
||||||
trace += k + " -> "
|
|
||||||
count++
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("%s. schema ref trace: %s", err, trace)
|
|
||||||
}
|
}
|
||||||
return root, nil
|
return root, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -226,7 +226,6 @@ func TestRootReferenceIsResolved(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSelfReferenceLoopErrors(t *testing.T) {
|
func TestSelfReferenceLoopErrors(t *testing.T) {
|
||||||
t.Skip()
|
|
||||||
specString := `{
|
specString := `{
|
||||||
"components": {
|
"components": {
|
||||||
"schemas": {
|
"schemas": {
|
||||||
|
@ -257,11 +256,10 @@ func TestSelfReferenceLoopErrors(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = reader.readResolvedSchema("#/components/schemas/fruits")
|
_, err = reader.readResolvedSchema("#/components/schemas/fruits")
|
||||||
assert.ErrorContains(t, err, "references loop detected. schema ref trace: #/components/schemas/fruits -> #/components/schemas/foo")
|
assert.ErrorContains(t, err, "references loop detected. traversal trace: -> #/components/schemas/fruits -> #/components/schemas/foo")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCrossReferenceLoopErrors(t *testing.T) {
|
func TestCrossReferenceLoopErrors(t *testing.T) {
|
||||||
t.Skip()
|
|
||||||
specString := `{
|
specString := `{
|
||||||
"components": {
|
"components": {
|
||||||
"schemas": {
|
"schemas": {
|
||||||
|
@ -292,7 +290,7 @@ func TestCrossReferenceLoopErrors(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = reader.readResolvedSchema("#/components/schemas/fruits")
|
_, err = reader.readResolvedSchema("#/components/schemas/fruits")
|
||||||
assert.ErrorContains(t, err, "references loop detected. schema ref trace: #/components/schemas/fruits -> #/components/schemas/foo")
|
assert.ErrorContains(t, err, "references loop detected. traversal trace: -> #/components/schemas/fruits -> #/components/schemas/foo")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReferenceResolutionForMapInObject(t *testing.T) {
|
func TestReferenceResolutionForMapInObject(t *testing.T) {
|
||||||
|
|
|
@ -65,7 +65,7 @@ func New(golangType reflect.Type, docs *Docs) (*Schema, error) {
|
||||||
tracker := newTracker()
|
tracker := newTracker()
|
||||||
schema, err := safeToSchema(golangType, docs, "", tracker)
|
schema, err := safeToSchema(golangType, docs, "", tracker)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, tracker.errWithTrace(err.Error())
|
return nil, tracker.errWithTrace(err.Error(), "root")
|
||||||
}
|
}
|
||||||
return schema, nil
|
return schema, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -662,17 +662,17 @@ func TestEmbeddedStructSchema(t *testing.T) {
|
||||||
func TestErrorWithTrace(t *testing.T) {
|
func TestErrorWithTrace(t *testing.T) {
|
||||||
tracker := newTracker()
|
tracker := newTracker()
|
||||||
dummyType := reflect.TypeOf(struct{}{})
|
dummyType := reflect.TypeOf(struct{}{})
|
||||||
err := tracker.errWithTrace("with empty trace")
|
err := tracker.errWithTrace("with empty trace", "root")
|
||||||
assert.ErrorContains(t, err, "[ERROR] with empty trace. traversal trace: root")
|
assert.ErrorContains(t, err, "with empty trace. traversal trace: root")
|
||||||
|
|
||||||
tracker.push(dummyType, "resources")
|
tracker.push(dummyType, "resources")
|
||||||
err = tracker.errWithTrace("with depth = 1")
|
err = tracker.errWithTrace("with depth = 1", "root")
|
||||||
assert.ErrorContains(t, err, "[ERROR] with depth = 1. traversal trace: root -> resources")
|
assert.ErrorContains(t, err, "with depth = 1. traversal trace: root -> resources")
|
||||||
|
|
||||||
tracker.push(dummyType, "pipelines")
|
tracker.push(dummyType, "pipelines")
|
||||||
tracker.push(dummyType, "datasets")
|
tracker.push(dummyType, "datasets")
|
||||||
err = tracker.errWithTrace("with depth = 4")
|
err = tracker.errWithTrace("with depth = 4", "root")
|
||||||
assert.ErrorContains(t, err, "[ERROR] with depth = 4. traversal trace: root -> resources -> pipelines -> datasets")
|
assert.ErrorContains(t, err, "with depth = 4. traversal trace: root -> resources -> pipelines -> datasets")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNonAnnotatedFieldsAreSkipped(t *testing.T) {
|
func TestNonAnnotatedFieldsAreSkipped(t *testing.T) {
|
||||||
|
@ -1356,7 +1356,7 @@ func TestErrorIfStructRefersToItself(t *testing.T) {
|
||||||
|
|
||||||
elem := Foo{}
|
elem := Foo{}
|
||||||
_, err := New(reflect.TypeOf(elem), nil)
|
_, err := New(reflect.TypeOf(elem), nil)
|
||||||
assert.ErrorContains(t, err, "ERROR] cycle detected. traversal trace: root -> my_foo")
|
assert.ErrorContains(t, err, "cycle detected. traversal trace: root -> my_foo")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestErrorIfStructHasLoop(t *testing.T) {
|
func TestErrorIfStructHasLoop(t *testing.T) {
|
||||||
|
@ -1373,7 +1373,7 @@ func TestErrorIfStructHasLoop(t *testing.T) {
|
||||||
|
|
||||||
elem := Apple{}
|
elem := Apple{}
|
||||||
_, err := New(reflect.TypeOf(elem), nil)
|
_, err := New(reflect.TypeOf(elem), nil)
|
||||||
assert.ErrorContains(t, err, "[ERROR] cycle detected. traversal trace: root -> my_mango -> my_guava -> my_papaya -> my_apple")
|
assert.ErrorContains(t, err, "cycle detected. traversal trace: root -> my_mango -> my_guava -> my_papaya -> my_apple")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInterfaceGeneratesEmptySchema(t *testing.T) {
|
func TestInterfaceGeneratesEmptySchema(t *testing.T) {
|
||||||
|
|
|
@ -3,52 +3,51 @@ package schema
|
||||||
import (
|
import (
|
||||||
"container/list"
|
"container/list"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type tracker struct {
|
type tracker struct {
|
||||||
// Types encountered in current path during the recursive traversal. Used to
|
// Nodes encountered in current path during the recursive traversal. Used to
|
||||||
// check for cycles
|
// check for cycles
|
||||||
seenTypes map[reflect.Type]struct{}
|
seenNodes map[interface{}]struct{}
|
||||||
|
|
||||||
// List of field names encountered in current path during the recursive traversal.
|
// List of node names encountered in order in current path during the recursive traversal.
|
||||||
// Used to hydrate errors with path to the exact node where error occurred.
|
// Used to hydrate errors with path to the exact node where error occurred.
|
||||||
//
|
//
|
||||||
// The field names here are the first tag in the json tags of struct field.
|
// NOTE: node and node names can be the same
|
||||||
debugTrace *list.List
|
listOfNodes *list.List
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTracker() *tracker {
|
func newTracker() *tracker {
|
||||||
return &tracker{
|
return &tracker{
|
||||||
seenTypes: map[reflect.Type]struct{}{},
|
seenNodes: map[interface{}]struct{}{},
|
||||||
debugTrace: list.New(),
|
listOfNodes: list.New(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tracker) errWithTrace(prefix string) error {
|
func (t *tracker) errWithTrace(prefix string, initTrace string) error {
|
||||||
traceString := "root"
|
traceString := initTrace
|
||||||
curr := t.debugTrace.Front()
|
curr := t.listOfNodes.Front()
|
||||||
for curr != nil {
|
for curr != nil {
|
||||||
if curr.Value.(string) != "" {
|
if curr.Value.(string) != "" {
|
||||||
traceString += " -> " + curr.Value.(string)
|
traceString += " -> " + curr.Value.(string)
|
||||||
}
|
}
|
||||||
curr = curr.Next()
|
curr = curr.Next()
|
||||||
}
|
}
|
||||||
return fmt.Errorf("[ERROR] " + prefix + ". traversal trace: " + traceString)
|
return fmt.Errorf(prefix + ". traversal trace: " + traceString)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tracker) hasCycle(golangType reflect.Type) bool {
|
func (t *tracker) hasCycle(node interface{}) bool {
|
||||||
_, ok := t.seenTypes[golangType]
|
_, ok := t.seenNodes[node]
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tracker) push(nodeType reflect.Type, jsonName string) {
|
func (t *tracker) push(node interface{}, name string) {
|
||||||
t.seenTypes[nodeType] = struct{}{}
|
t.seenNodes[node] = struct{}{}
|
||||||
t.debugTrace.PushBack(jsonName)
|
t.listOfNodes.PushBack(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tracker) pop(nodeType reflect.Type) {
|
func (t *tracker) pop(nodeType interface{}) {
|
||||||
back := t.debugTrace.Back()
|
back := t.listOfNodes.Back()
|
||||||
t.debugTrace.Remove(back)
|
t.listOfNodes.Remove(back)
|
||||||
delete(t.seenTypes, nodeType)
|
delete(t.seenNodes, nodeType)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue