From 00e5896966b5b386ceff2eddceaa3689ab9a8792 Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Tue, 20 Aug 2024 21:53:04 +0200 Subject: [PATCH] add test for recursive --- libs/jsonschema/from_type.go | 70 ++++++++++++++++-------- libs/jsonschema/from_type_test.go | 64 ++++++++++++++++++++++ libs/jsonschema/test_types/test_types.go | 15 +++++ 3 files changed, 127 insertions(+), 22 deletions(-) create mode 100644 libs/jsonschema/test_types/test_types.go diff --git a/libs/jsonschema/from_type.go b/libs/jsonschema/from_type.go index ee29684b..2c1d4ca6 100644 --- a/libs/jsonschema/from_type.go +++ b/libs/jsonschema/from_type.go @@ -35,6 +35,8 @@ type constructor struct { // Example key: github.com/databricks/databricks-sdk-go/service/jobs.JobSettings definitions map[string]Schema + seen map[string]struct{} + // Transformation function to apply after generating a node in the schema. fn func(s Schema) Schema } @@ -74,10 +76,11 @@ func (c *constructor) nestedDefinitions() any { func FromType(typ reflect.Type, fn func(s Schema) Schema) (Schema, error) { c := constructor{ definitions: make(map[string]Schema), + seen: make(map[string]struct{}), fn: fn, } - _, err := c.walk(typ) + err := c.walk(typ) if err != nil { return InvalidSchema, err } @@ -90,6 +93,11 @@ func FromType(typ reflect.Type, fn func(s Schema) Schema) (Schema, error) { } func typePath(typ reflect.Type) string { + // Pointers have a typ.Name() of "". Dereference them to get the underlying type. + for typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + // typ.Name() resolves to "" for any type. if typ.Kind() == reflect.Interface { return "interface" @@ -105,7 +113,7 @@ func typePath(typ reflect.Type) string { // TODO: would a worked based model fit better here? Is this internal API not // the right fit? -func (c *constructor) walk(typ reflect.Type) (string, error) { +func (c *constructor) walk(typ reflect.Type) error { // Dereference pointers if necessary. for typ.Kind() == reflect.Ptr { typ = typ.Elem() @@ -113,9 +121,14 @@ func (c *constructor) walk(typ reflect.Type) (string, error) { typPath := typePath(typ) - // Return value directly if it's already been processed. + // Keep track of seen types to avoid infinite recursion. + if _, ok := c.seen[typPath]; !ok { + c.seen[typPath] = struct{}{} + } + + // Return early directly if it's already been processed. if _, ok := c.definitions[typPath]; ok { - return typPath, nil + return nil } var s Schema @@ -144,10 +157,10 @@ func (c *constructor) walk(typ reflect.Type) (string, error) { // set to null and disallowed in the schema. s = Schema{Type: NullType} default: - return "", fmt.Errorf("unsupported type: %s", typ.Kind()) + return fmt.Errorf("unsupported type: %s", typ.Kind()) } if err != nil { - return "", err + return err } if c.fn != nil { @@ -158,7 +171,7 @@ func (c *constructor) walk(typ reflect.Type) (string, error) { // TODO: Apply transformation at the end, to all definitions instead of // during recursive traversal? c.definitions[typPath] = s - return typPath, nil + return nil } // This function returns all member fields of the provided type. @@ -193,6 +206,7 @@ func getStructFields(typ reflect.Type) []reflect.StructField { return fields } +// TODO: get rid of the errors here and panic instead? func (c *constructor) fromTypeStruct(typ reflect.Type) (Schema, error) { if typ.Kind() != reflect.Struct { return InvalidSchema, fmt.Errorf("expected struct, got %s", typ.Kind()) @@ -233,11 +247,15 @@ func (c *constructor) fromTypeStruct(typ reflect.Type) (Schema, error) { res.Required = append(res.Required, jsonTags[0]) } - // Trigger call to fromType, to recursively generate definitions for - // the struct field. - typPath, err := c.walk(structField.Type) - if err != nil { - return InvalidSchema, err + typPath := typePath(structField.Type) + // Only walk if the type has not been seen yet. + if _, ok := c.seen[typPath]; !ok { + // Trigger call to fromType, to recursively generate definitions for + // the struct field. + err := c.walk(structField.Type) + if err != nil { + return InvalidSchema, err + } } refPath := path.Join("#/$defs", typPath) @@ -261,11 +279,15 @@ func (c *constructor) fromTypeSlice(typ reflect.Type) (Schema, error) { Type: ArrayType, } - // Trigger call to fromType, to recursively generate definitions for - // the slice element. - typPath, err := c.walk(typ.Elem()) - if err != nil { - return InvalidSchema, err + typPath := typePath(typ.Elem()) + // Only walk if the type has not been seen yet. + if _, ok := c.seen[typPath]; !ok { + // Trigger call to fromType, to recursively generate definitions for + // the slice element. + err := c.walk(typ.Elem()) + if err != nil { + return InvalidSchema, err + } } refPath := path.Join("#/$defs", typPath) @@ -290,11 +312,15 @@ func (c *constructor) fromTypeMap(typ reflect.Type) (Schema, error) { Type: ObjectType, } - // Trigger call to fromType, to recursively generate definitions for - // the map value. - typPath, err := c.walk(typ.Elem()) - if err != nil { - return InvalidSchema, err + typPath := typePath(typ.Elem()) + // Only walk if the type has not been seen yet. + if _, ok := c.seen[typPath]; !ok { + // Trigger call to fromType, to recursively generate definitions for + // the map value. + err := c.walk(typ.Elem()) + if err != nil { + return InvalidSchema, err + } } refPath := path.Join("#/$defs", typPath) diff --git a/libs/jsonschema/from_type_test.go b/libs/jsonschema/from_type_test.go index e70e75c4..24d6e99c 100644 --- a/libs/jsonschema/from_type_test.go +++ b/libs/jsonschema/from_type_test.go @@ -1,9 +1,11 @@ package jsonschema import ( + "encoding/json" "reflect" "testing" + "github.com/databricks/cli/libs/jsonschema/test_types" "github.com/stretchr/testify/assert" ) @@ -259,3 +261,65 @@ func TestFromTypeNested(t *testing.T) { }) } } + +// TODO: Call out in the PR description that recursive Go types are supported. +func TestFromTypeRecursive(t *testing.T) { + fooRef := "#/$defs/github.com/databricks/cli/libs/jsonschema/test_types.Foo" + barRef := "#/$defs/github.com/databricks/cli/libs/jsonschema/test_types.Bar" + + expected := Schema{ + Type: "object", + Definitions: map[string]any{ + "github.com": map[string]any{ + "databricks": map[string]any{ + "cli": map[string]any{ + "libs": map[string]any{ + "jsonschema": map[string]any{ + "test_types.Bar": Schema{ + Type: "object", + Properties: map[string]*Schema{ + "foo": { + Reference: &fooRef, + }, + }, + AdditionalProperties: false, + Required: []string{}, + }, + "test_types.Foo": Schema{ + Type: "object", + Properties: map[string]*Schema{ + "bar": { + Reference: &barRef, + }, + }, + AdditionalProperties: false, + Required: []string{}, + }, + }, + }, + }, + }, + }, + }, + Properties: map[string]*Schema{ + "foo": { + Reference: &fooRef, + }, + }, + AdditionalProperties: false, + Required: []string{"foo"}, + } + + s, err := FromType(reflect.TypeOf(test_types.Outer{}), nil) + assert.NoError(t, err) + assert.Equal(t, expected, s) + + jsonSchema, err := json.MarshalIndent(s, " ", " ") + assert.NoError(t, err) + + expectedJson, err := json.MarshalIndent(expected, " ", " ") + assert.NoError(t, err) + + t.Log("[DEBUG] actual: ", string(jsonSchema)) + t.Log("[DEBUG] expected: ", string(expectedJson)) +} diff --git a/libs/jsonschema/test_types/test_types.go b/libs/jsonschema/test_types/test_types.go new file mode 100644 index 00000000..febe5c33 --- /dev/null +++ b/libs/jsonschema/test_types/test_types.go @@ -0,0 +1,15 @@ +package test_types + +// Recursive types cannot be defined inline without making them anonymous, +// so we define them here instead. +type Foo struct { + Bar *Bar `json:"bar,omitempty"` +} + +type Bar struct { + Foo Foo `json:"foo,omitempty"` +} + +type Outer struct { + Foo Foo `json:"foo"` +}