add test for recursive

This commit is contained in:
Shreyas Goenka 2024-08-20 21:53:04 +02:00
parent 460eeb928d
commit 00e5896966
No known key found for this signature in database
GPG Key ID: 92A07DF49CCB0622
3 changed files with 127 additions and 22 deletions

View File

@ -35,6 +35,8 @@ type constructor struct {
// Example key: github.com/databricks/databricks-sdk-go/service/jobs.JobSettings // Example key: github.com/databricks/databricks-sdk-go/service/jobs.JobSettings
definitions map[string]Schema definitions map[string]Schema
seen map[string]struct{}
// Transformation function to apply after generating a node in the schema. // Transformation function to apply after generating a node in the schema.
fn func(s Schema) Schema fn func(s Schema) Schema
} }
@ -74,10 +76,11 @@ func (c *constructor) nestedDefinitions() any {
func FromType(typ reflect.Type, fn func(s Schema) Schema) (Schema, error) { func FromType(typ reflect.Type, fn func(s Schema) Schema) (Schema, error) {
c := constructor{ c := constructor{
definitions: make(map[string]Schema), definitions: make(map[string]Schema),
seen: make(map[string]struct{}),
fn: fn, fn: fn,
} }
_, err := c.walk(typ) err := c.walk(typ)
if err != nil { if err != nil {
return InvalidSchema, err return InvalidSchema, err
} }
@ -90,6 +93,11 @@ func FromType(typ reflect.Type, fn func(s Schema) Schema) (Schema, error) {
} }
func typePath(typ reflect.Type) string { func typePath(typ reflect.Type) string {
// Pointers have a typ.Name() of "". Dereference them to get the underlying type.
for typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
// typ.Name() resolves to "" for any type. // typ.Name() resolves to "" for any type.
if typ.Kind() == reflect.Interface { if typ.Kind() == reflect.Interface {
return "interface" return "interface"
@ -105,7 +113,7 @@ func typePath(typ reflect.Type) string {
// TODO: would a worked based model fit better here? Is this internal API not // TODO: would a worked based model fit better here? Is this internal API not
// the right fit? // the right fit?
func (c *constructor) walk(typ reflect.Type) (string, error) { func (c *constructor) walk(typ reflect.Type) error {
// Dereference pointers if necessary. // Dereference pointers if necessary.
for typ.Kind() == reflect.Ptr { for typ.Kind() == reflect.Ptr {
typ = typ.Elem() typ = typ.Elem()
@ -113,9 +121,14 @@ func (c *constructor) walk(typ reflect.Type) (string, error) {
typPath := typePath(typ) typPath := typePath(typ)
// Return value directly if it's already been processed. // Keep track of seen types to avoid infinite recursion.
if _, ok := c.seen[typPath]; !ok {
c.seen[typPath] = struct{}{}
}
// Return early directly if it's already been processed.
if _, ok := c.definitions[typPath]; ok { if _, ok := c.definitions[typPath]; ok {
return typPath, nil return nil
} }
var s Schema var s Schema
@ -144,10 +157,10 @@ func (c *constructor) walk(typ reflect.Type) (string, error) {
// set to null and disallowed in the schema. // set to null and disallowed in the schema.
s = Schema{Type: NullType} s = Schema{Type: NullType}
default: default:
return "", fmt.Errorf("unsupported type: %s", typ.Kind()) return fmt.Errorf("unsupported type: %s", typ.Kind())
} }
if err != nil { if err != nil {
return "", err return err
} }
if c.fn != nil { if c.fn != nil {
@ -158,7 +171,7 @@ func (c *constructor) walk(typ reflect.Type) (string, error) {
// TODO: Apply transformation at the end, to all definitions instead of // TODO: Apply transformation at the end, to all definitions instead of
// during recursive traversal? // during recursive traversal?
c.definitions[typPath] = s c.definitions[typPath] = s
return typPath, nil return nil
} }
// This function returns all member fields of the provided type. // This function returns all member fields of the provided type.
@ -193,6 +206,7 @@ func getStructFields(typ reflect.Type) []reflect.StructField {
return fields return fields
} }
// TODO: get rid of the errors here and panic instead?
func (c *constructor) fromTypeStruct(typ reflect.Type) (Schema, error) { func (c *constructor) fromTypeStruct(typ reflect.Type) (Schema, error) {
if typ.Kind() != reflect.Struct { if typ.Kind() != reflect.Struct {
return InvalidSchema, fmt.Errorf("expected struct, got %s", typ.Kind()) return InvalidSchema, fmt.Errorf("expected struct, got %s", typ.Kind())
@ -233,12 +247,16 @@ func (c *constructor) fromTypeStruct(typ reflect.Type) (Schema, error) {
res.Required = append(res.Required, jsonTags[0]) res.Required = append(res.Required, jsonTags[0])
} }
typPath := typePath(structField.Type)
// Only walk if the type has not been seen yet.
if _, ok := c.seen[typPath]; !ok {
// Trigger call to fromType, to recursively generate definitions for // Trigger call to fromType, to recursively generate definitions for
// the struct field. // the struct field.
typPath, err := c.walk(structField.Type) err := c.walk(structField.Type)
if err != nil { if err != nil {
return InvalidSchema, err return InvalidSchema, err
} }
}
refPath := path.Join("#/$defs", typPath) refPath := path.Join("#/$defs", typPath)
// For non-built-in types, refer to the definition. // For non-built-in types, refer to the definition.
@ -261,12 +279,16 @@ func (c *constructor) fromTypeSlice(typ reflect.Type) (Schema, error) {
Type: ArrayType, Type: ArrayType,
} }
typPath := typePath(typ.Elem())
// Only walk if the type has not been seen yet.
if _, ok := c.seen[typPath]; !ok {
// Trigger call to fromType, to recursively generate definitions for // Trigger call to fromType, to recursively generate definitions for
// the slice element. // the slice element.
typPath, err := c.walk(typ.Elem()) err := c.walk(typ.Elem())
if err != nil { if err != nil {
return InvalidSchema, err return InvalidSchema, err
} }
}
refPath := path.Join("#/$defs", typPath) refPath := path.Join("#/$defs", typPath)
@ -290,12 +312,16 @@ func (c *constructor) fromTypeMap(typ reflect.Type) (Schema, error) {
Type: ObjectType, Type: ObjectType,
} }
typPath := typePath(typ.Elem())
// Only walk if the type has not been seen yet.
if _, ok := c.seen[typPath]; !ok {
// Trigger call to fromType, to recursively generate definitions for // Trigger call to fromType, to recursively generate definitions for
// the map value. // the map value.
typPath, err := c.walk(typ.Elem()) err := c.walk(typ.Elem())
if err != nil { if err != nil {
return InvalidSchema, err return InvalidSchema, err
} }
}
refPath := path.Join("#/$defs", typPath) refPath := path.Join("#/$defs", typPath)

View File

@ -1,9 +1,11 @@
package jsonschema package jsonschema
import ( import (
"encoding/json"
"reflect" "reflect"
"testing" "testing"
"github.com/databricks/cli/libs/jsonschema/test_types"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -259,3 +261,65 @@ func TestFromTypeNested(t *testing.T) {
}) })
} }
} }
// TODO: Call out in the PR description that recursive Go types are supported.
func TestFromTypeRecursive(t *testing.T) {
fooRef := "#/$defs/github.com/databricks/cli/libs/jsonschema/test_types.Foo"
barRef := "#/$defs/github.com/databricks/cli/libs/jsonschema/test_types.Bar"
expected := Schema{
Type: "object",
Definitions: map[string]any{
"github.com": map[string]any{
"databricks": map[string]any{
"cli": map[string]any{
"libs": map[string]any{
"jsonschema": map[string]any{
"test_types.Bar": Schema{
Type: "object",
Properties: map[string]*Schema{
"foo": {
Reference: &fooRef,
},
},
AdditionalProperties: false,
Required: []string{},
},
"test_types.Foo": Schema{
Type: "object",
Properties: map[string]*Schema{
"bar": {
Reference: &barRef,
},
},
AdditionalProperties: false,
Required: []string{},
},
},
},
},
},
},
},
Properties: map[string]*Schema{
"foo": {
Reference: &fooRef,
},
},
AdditionalProperties: false,
Required: []string{"foo"},
}
s, err := FromType(reflect.TypeOf(test_types.Outer{}), nil)
assert.NoError(t, err)
assert.Equal(t, expected, s)
jsonSchema, err := json.MarshalIndent(s, " ", " ")
assert.NoError(t, err)
expectedJson, err := json.MarshalIndent(expected, " ", " ")
assert.NoError(t, err)
t.Log("[DEBUG] actual: ", string(jsonSchema))
t.Log("[DEBUG] expected: ", string(expectedJson))
}

View File

@ -0,0 +1,15 @@
package test_types
// Recursive types cannot be defined inline without making them anonymous,
// so we define them here instead.
type Foo struct {
Bar *Bar `json:"bar,omitempty"`
}
type Bar struct {
Foo Foo `json:"foo,omitempty"`
}
type Outer struct {
Foo Foo `json:"foo"`
}