add test for recursive

This commit is contained in:
Shreyas Goenka 2024-08-20 21:53:04 +02:00
parent 460eeb928d
commit 00e5896966
No known key found for this signature in database
GPG Key ID: 92A07DF49CCB0622
3 changed files with 127 additions and 22 deletions

View File

@ -35,6 +35,8 @@ type constructor struct {
// Example key: github.com/databricks/databricks-sdk-go/service/jobs.JobSettings
definitions map[string]Schema
seen map[string]struct{}
// Transformation function to apply after generating a node in the schema.
fn func(s Schema) Schema
}
@ -74,10 +76,11 @@ func (c *constructor) nestedDefinitions() any {
func FromType(typ reflect.Type, fn func(s Schema) Schema) (Schema, error) {
c := constructor{
definitions: make(map[string]Schema),
seen: make(map[string]struct{}),
fn: fn,
}
_, err := c.walk(typ)
err := c.walk(typ)
if err != nil {
return InvalidSchema, err
}
@ -90,6 +93,11 @@ func FromType(typ reflect.Type, fn func(s Schema) Schema) (Schema, error) {
}
func typePath(typ reflect.Type) string {
// Pointers have a typ.Name() of "". Dereference them to get the underlying type.
for typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
// typ.Name() resolves to "" for any type.
if typ.Kind() == reflect.Interface {
return "interface"
@ -105,7 +113,7 @@ func typePath(typ reflect.Type) string {
// TODO: would a worked based model fit better here? Is this internal API not
// the right fit?
func (c *constructor) walk(typ reflect.Type) (string, error) {
func (c *constructor) walk(typ reflect.Type) error {
// Dereference pointers if necessary.
for typ.Kind() == reflect.Ptr {
typ = typ.Elem()
@ -113,9 +121,14 @@ func (c *constructor) walk(typ reflect.Type) (string, error) {
typPath := typePath(typ)
// Return value directly if it's already been processed.
// Keep track of seen types to avoid infinite recursion.
if _, ok := c.seen[typPath]; !ok {
c.seen[typPath] = struct{}{}
}
// Return early directly if it's already been processed.
if _, ok := c.definitions[typPath]; ok {
return typPath, nil
return nil
}
var s Schema
@ -144,10 +157,10 @@ func (c *constructor) walk(typ reflect.Type) (string, error) {
// set to null and disallowed in the schema.
s = Schema{Type: NullType}
default:
return "", fmt.Errorf("unsupported type: %s", typ.Kind())
return fmt.Errorf("unsupported type: %s", typ.Kind())
}
if err != nil {
return "", err
return err
}
if c.fn != nil {
@ -158,7 +171,7 @@ func (c *constructor) walk(typ reflect.Type) (string, error) {
// TODO: Apply transformation at the end, to all definitions instead of
// during recursive traversal?
c.definitions[typPath] = s
return typPath, nil
return nil
}
// This function returns all member fields of the provided type.
@ -193,6 +206,7 @@ func getStructFields(typ reflect.Type) []reflect.StructField {
return fields
}
// TODO: get rid of the errors here and panic instead?
func (c *constructor) fromTypeStruct(typ reflect.Type) (Schema, error) {
if typ.Kind() != reflect.Struct {
return InvalidSchema, fmt.Errorf("expected struct, got %s", typ.Kind())
@ -233,11 +247,15 @@ func (c *constructor) fromTypeStruct(typ reflect.Type) (Schema, error) {
res.Required = append(res.Required, jsonTags[0])
}
// Trigger call to fromType, to recursively generate definitions for
// the struct field.
typPath, err := c.walk(structField.Type)
if err != nil {
return InvalidSchema, err
typPath := typePath(structField.Type)
// Only walk if the type has not been seen yet.
if _, ok := c.seen[typPath]; !ok {
// Trigger call to fromType, to recursively generate definitions for
// the struct field.
err := c.walk(structField.Type)
if err != nil {
return InvalidSchema, err
}
}
refPath := path.Join("#/$defs", typPath)
@ -261,11 +279,15 @@ func (c *constructor) fromTypeSlice(typ reflect.Type) (Schema, error) {
Type: ArrayType,
}
// Trigger call to fromType, to recursively generate definitions for
// the slice element.
typPath, err := c.walk(typ.Elem())
if err != nil {
return InvalidSchema, err
typPath := typePath(typ.Elem())
// Only walk if the type has not been seen yet.
if _, ok := c.seen[typPath]; !ok {
// Trigger call to fromType, to recursively generate definitions for
// the slice element.
err := c.walk(typ.Elem())
if err != nil {
return InvalidSchema, err
}
}
refPath := path.Join("#/$defs", typPath)
@ -290,11 +312,15 @@ func (c *constructor) fromTypeMap(typ reflect.Type) (Schema, error) {
Type: ObjectType,
}
// Trigger call to fromType, to recursively generate definitions for
// the map value.
typPath, err := c.walk(typ.Elem())
if err != nil {
return InvalidSchema, err
typPath := typePath(typ.Elem())
// Only walk if the type has not been seen yet.
if _, ok := c.seen[typPath]; !ok {
// Trigger call to fromType, to recursively generate definitions for
// the map value.
err := c.walk(typ.Elem())
if err != nil {
return InvalidSchema, err
}
}
refPath := path.Join("#/$defs", typPath)

View File

@ -1,9 +1,11 @@
package jsonschema
import (
"encoding/json"
"reflect"
"testing"
"github.com/databricks/cli/libs/jsonschema/test_types"
"github.com/stretchr/testify/assert"
)
@ -259,3 +261,65 @@ func TestFromTypeNested(t *testing.T) {
})
}
}
// TODO: Call out in the PR description that recursive Go types are supported.
func TestFromTypeRecursive(t *testing.T) {
fooRef := "#/$defs/github.com/databricks/cli/libs/jsonschema/test_types.Foo"
barRef := "#/$defs/github.com/databricks/cli/libs/jsonschema/test_types.Bar"
expected := Schema{
Type: "object",
Definitions: map[string]any{
"github.com": map[string]any{
"databricks": map[string]any{
"cli": map[string]any{
"libs": map[string]any{
"jsonschema": map[string]any{
"test_types.Bar": Schema{
Type: "object",
Properties: map[string]*Schema{
"foo": {
Reference: &fooRef,
},
},
AdditionalProperties: false,
Required: []string{},
},
"test_types.Foo": Schema{
Type: "object",
Properties: map[string]*Schema{
"bar": {
Reference: &barRef,
},
},
AdditionalProperties: false,
Required: []string{},
},
},
},
},
},
},
},
Properties: map[string]*Schema{
"foo": {
Reference: &fooRef,
},
},
AdditionalProperties: false,
Required: []string{"foo"},
}
s, err := FromType(reflect.TypeOf(test_types.Outer{}), nil)
assert.NoError(t, err)
assert.Equal(t, expected, s)
jsonSchema, err := json.MarshalIndent(s, " ", " ")
assert.NoError(t, err)
expectedJson, err := json.MarshalIndent(expected, " ", " ")
assert.NoError(t, err)
t.Log("[DEBUG] actual: ", string(jsonSchema))
t.Log("[DEBUG] expected: ", string(expectedJson))
}

View File

@ -0,0 +1,15 @@
package test_types
// Recursive types cannot be defined inline without making them anonymous,
// so we define them here instead.
type Foo struct {
Bar *Bar `json:"bar,omitempty"`
}
type Bar struct {
Foo Foo `json:"foo,omitempty"`
}
type Outer struct {
Foo Foo `json:"foo"`
}