added tracker struct and refactored a bit

This commit is contained in:
Shreyas Goenka 2023-01-19 18:23:37 +01:00
parent 3c5ee69941
commit 333b9f7acc
No known key found for this signature in database
GPG Key ID: 92A07DF49CCB0622
3 changed files with 76 additions and 42 deletions

View File

@ -61,11 +61,10 @@ between json schema types and golang types
for details visit: https://json-schema.org/understanding-json-schema/reference/object.html#properties
*/
func NewSchema(golangType reflect.Type, docs *Docs) (*Schema, error) {
seenTypes := map[reflect.Type]struct{}{}
debugTrace := list.New()
schema, err := safeToSchema(golangType, docs, "", seenTypes, debugTrace)
tracker := newTracker()
schema, err := safeToSchema(golangType, docs, "", tracker)
if err != nil {
return nil, errWithTrace(err.Error(), debugTrace)
return nil, tracker.errWithTrace(err.Error())
}
return schema, nil
}
@ -106,20 +105,8 @@ func javascriptType(golangType reflect.Type) (JavascriptType, error) {
}
}
func errWithTrace(prefix string, trace *list.List) error {
traceString := "root"
curr := trace.Front()
for curr != nil {
if curr.Value.(string) != "" {
traceString += " -> " + curr.Value.(string)
}
curr = curr.Next()
}
return fmt.Errorf("[ERROR] " + prefix + ". traversal trace: " + traceString)
}
// A wrapper over toSchema function to detect cycles in the bundle config struct
func safeToSchema(golangType reflect.Type, docs *Docs, debugTraceId string, seenTypes map[reflect.Type]struct{}, debugTrace *list.List) (*Schema, error) {
func safeToSchema(golangType reflect.Type, docs *Docs, traceId string, tracker *tracker) (*Schema, error) {
// WE ERROR OUT IF THERE ARE CYCLES IN THE JSON SCHEMA
// There are mechanisms to deal with cycles though recursive identifiers in json
// schema. However if we use them, we would need to make sure we are able to detect
@ -127,24 +114,16 @@ func safeToSchema(golangType reflect.Type, docs *Docs, debugTraceId string, seen
//
// see: https://json-schema.org/understanding-json-schema/structuring.html#recursion
// for details
_, ok := seenTypes[golangType]
if ok {
fmt.Println("[DEBUG] traceSet: ", seenTypes)
if tracker.hasCycle(golangType) {
return nil, fmt.Errorf("cycle detected")
}
// Update set of types in current path
seenTypes[golangType] = struct{}{}
// Add the json tag name of struct field to debug trace
debugTrace.PushBack(debugTraceId)
props, err := toSchema(golangType, docs, seenTypes, debugTrace)
tracker.step(golangType, traceId)
props, err := toSchema(golangType, docs, tracker)
if err != nil {
return nil, err
}
back := debugTrace.Back()
debugTrace.Remove(back)
delete(seenTypes, golangType)
tracker.undoStep(golangType)
return props, nil
}
@ -188,10 +167,10 @@ func getStructFields(golangType reflect.Type) []reflect.StructField {
// Used to identify cycles.
// debugTrace: linked list of golang types encounted. In case of errors this
// helps log where the error originated from
func toSchema(golangType reflect.Type, docs *Docs, seenTypes map[reflect.Type]struct{}, debugTrace *list.List) (*Schema, error) {
func toSchema(golangType reflect.Type, docs *Docs, tracker *tracker) (*Schema, error) {
// *Struct and Struct generate identical json schemas
if golangType.Kind() == reflect.Pointer {
return safeToSchema(golangType.Elem(), docs, "", seenTypes, debugTrace)
return safeToSchema(golangType.Elem(), docs, "", tracker)
}
if golangType.Kind() == reflect.Interface {
@ -216,7 +195,7 @@ func toSchema(golangType reflect.Type, docs *Docs, seenTypes map[reflect.Type]st
if err != nil {
return nil, err
}
elemProps, err := safeToSchema(elemGolangType, docs, "", seenTypes, debugTrace)
elemProps, err := safeToSchema(elemGolangType, docs, "", tracker)
if err != nil {
return nil, err
}
@ -235,7 +214,7 @@ func toSchema(golangType reflect.Type, docs *Docs, seenTypes map[reflect.Type]st
if golangType.Key().Kind() != reflect.String {
return nil, fmt.Errorf("only string keyed maps allowed")
}
additionalProperties, err = safeToSchema(golangType.Elem(), docs, "", seenTypes, debugTrace)
additionalProperties, err = safeToSchema(golangType.Elem(), docs, "", tracker)
if err != nil {
return nil, err
}
@ -278,7 +257,7 @@ func toSchema(golangType reflect.Type, docs *Docs, seenTypes map[reflect.Type]st
}
// compute Schema.Properties for the child recursively
fieldProps, err := safeToSchema(child.Type, childDocs, childName, seenTypes, debugTrace)
fieldProps, err := safeToSchema(child.Type, childDocs, childName, tracker)
if err != nil {
return nil, err
}

View File

@ -1,7 +1,6 @@
package schema
import (
"container/list"
"encoding/json"
"reflect"
"testing"
@ -661,17 +660,18 @@ func TestEmbeddedStructSchema(t *testing.T) {
}
func TestErrorWithTrace(t *testing.T) {
debugTrace := list.New()
err := errWithTrace("with empty trace", debugTrace)
tracker := newTracker()
dummyType := reflect.TypeOf(struct{}{})
err := tracker.errWithTrace("with empty trace")
assert.ErrorContains(t, err, "[ERROR] with empty trace. traversal trace: root")
debugTrace.PushBack("resources")
err = errWithTrace("with depth = 1", debugTrace)
tracker.step(dummyType, "resources")
err = tracker.errWithTrace("with depth = 1")
assert.ErrorContains(t, err, "[ERROR] with depth = 1. traversal trace: root -> resources")
debugTrace.PushBack("pipelines")
debugTrace.PushBack("datasets")
err = errWithTrace("with depth = 4", debugTrace)
tracker.step(dummyType, "pipelines")
tracker.step(dummyType, "datasets")
err = tracker.errWithTrace("with depth = 4")
assert.ErrorContains(t, err, "[ERROR] with depth = 4. traversal trace: root -> resources -> pipelines -> datasets")
}

55
bundle/schema/tracker.go Normal file
View File

@ -0,0 +1,55 @@
package schema
import (
"container/list"
"fmt"
"reflect"
)
type tracker struct {
// Types encountered in path of reaching the current type. Used to deletect
// cycles
seenTypes map[reflect.Type]struct{}
// List of names from json tag encountered while reaching current type. This
// is logged on any error so we know on which type an error occured
debugTrace *list.List
}
func newTracker() *tracker {
return &tracker{
seenTypes: map[reflect.Type]struct{}{},
debugTrace: list.New(),
}
}
func (t *tracker) errWithTrace(prefix string) error {
traceString := "root"
curr := t.debugTrace.Front()
for curr != nil {
if curr.Value.(string) != "" {
traceString += " -> " + curr.Value.(string)
}
curr = curr.Next()
}
return fmt.Errorf("[ERROR] " + prefix + ". traversal trace: " + traceString)
}
func (t *tracker) hasCycle(golangType reflect.Type) bool {
_, ok := t.seenTypes[golangType]
if !ok {
fmt.Println("[DEBUG] traceSet for cycle: ", t.seenTypes)
}
return ok
}
func (t *tracker) step(nodeType reflect.Type, jsonName string) {
t.seenTypes[nodeType] = struct{}{}
t.debugTrace.PushBack(jsonName)
}
func (t *tracker) undoStep(nodeType reflect.Type) {
back := t.debugTrace.Back()
t.debugTrace.Remove(back)
delete(t.seenTypes, nodeType)
}