cleaned up test boilerplate, added support for maps and embedded types

This commit is contained in:
Shreyas Goenka 2023-01-13 19:08:21 +01:00
parent 010f60d498
commit b1e85e20f6
No known key found for this signature in database
GPG Key ID: 92A07DF49CCB0622
2 changed files with 140 additions and 72 deletions

View File

@ -1,6 +1,7 @@
package schema package schema
import ( import (
"container/list"
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
@ -25,8 +26,8 @@ type Property struct {
// TODO: panic for now, add support for adding schemas to $defs in case of cycles // TODO: panic for now, add support for adding schemas to $defs in case of cycles
type Item struct { type Item struct {
Type JsType `json:"type"` Type JsType `json:"type"`
Properities map[string]*Property `json:"properties,omitempty"` Properities map[string]*Property `json:"properties,omitempty"`
} }
func NewSchema(golangType reflect.Type) (*Schema, error) { func NewSchema(golangType reflect.Type) (*Schema, error) {
@ -43,6 +44,8 @@ func NewSchema(golangType reflect.Type) (*Schema, error) {
}, nil }, nil
} }
// TODO: add tests for errors being triggered
type JsType string type JsType string
const ( const (
@ -67,7 +70,6 @@ func javascriptType(golangType reflect.Type) (JsType, error) {
return Number, nil return Number, nil
case reflect.Struct: case reflect.Struct:
return Object, nil return Object, nil
// TODO: add support for pattern properities to account for maps
case reflect.Map: case reflect.Map:
if golangType.Key().Kind() != reflect.String { if golangType.Key().Kind() != reflect.String {
return Invalid, fmt.Errorf("only strings map keys are valid. key type: %v", golangType.Key().Kind()) return Invalid, fmt.Errorf("only strings map keys are valid. key type: %v", golangType.Key().Kind())
@ -94,60 +96,114 @@ func errWithTrace(prefix string, trace []reflect.Type) error {
// TODO: add tests for the error cases, forcefully triggering them // TODO: add tests for the error cases, forcefully triggering them
// checks and errors out for cycles
// wraps the error with context
func safeToProperty(golangType reflect.Type, traceSet map[reflect.Type]struct{}, traceSlice []reflect.Type) (*Property, error) {
traceSlice = append(traceSlice, golangType)
// detect cycles. Fail if a cycle is detected
// TODO: Add references here for cycles
// TODO: move this check somewhere nicer
_, ok := traceSet[golangType]
if ok {
fmt.Println("[DEBUG] traceSet: ", traceSet)
return nil, errWithTrace("cycle detected", traceSlice)
}
// add current child field to history
traceSet[golangType] = struct{}{}
props, err := toProperity(golangType, traceSet, traceSlice)
if err != nil {
return nil, errWithTrace(err.Error(), traceSlice)
}
delete(traceSet, golangType)
traceSlice = traceSlice[:len(traceSlice)-1]
return props, nil
}
func pop(q []reflect.StructField) reflect.StructField {
elem := q[0]
q = q[1:]
return elem
}
func push(q []reflect.StructField, r reflect.StructField) {
q = append(q, r)
}
// travels anonymous embedded fields in a bfs manner to give us a list of all
// member fields of a struct
// simple Tree based traversal will take place because embbedded fields cannot
// form a cycle
func addStructFields(fields []reflect.StructField, golangType reflect.Type) []reflect.StructField {
bfsQueue := list.New()
for i := 0; i < golangType.NumField(); i++ {
bfsQueue.PushBack(golangType.Field(i))
}
for bfsQueue.Len() > 0 {
front := bfsQueue.Front()
field := front.Value.(reflect.StructField)
bfsQueue.Remove(front)
if !field.Anonymous {
fields = append(fields, field)
continue
}
for i := 0; i < field.Type.NumField(); i++ {
bfsQueue.PushBack(field.Type.Field(i))
}
}
return fields
}
// TODO: add doc string explaining numHistoryOccurances // TODO: add doc string explaining numHistoryOccurances
func toProperity(golangType reflect.Type, traceSet map[reflect.Type]struct{}, traceSlice []reflect.Type) (*Property, error) { func toProperity(golangType reflect.Type, traceSet map[reflect.Type]struct{}, traceSlice []reflect.Type) (*Property, error) {
traceSlice = append(traceSlice, golangType)
// *Struct and Struct generate identical json schemas // *Struct and Struct generate identical json schemas
if golangType.Kind() == reflect.Pointer { if golangType.Kind() == reflect.Pointer {
return toProperity(golangType.Elem(), traceSet, traceSlice) return toProperity(golangType.Elem(), traceSet, traceSlice)
} }
rootJavascriptType, err := javascriptType(golangType) rootJavascriptType, err := javascriptType(golangType)
if err != nil { if err != nil {
return nil, errWithTrace(err.Error(), traceSlice) return nil, err
} }
// case array/slice
var items *Item var items *Item
if golangType.Kind() == reflect.Array || golangType.Kind() == reflect.Slice { if golangType.Kind() == reflect.Array || golangType.Kind() == reflect.Slice {
elemGolangType := golangType.Elem() elemGolangType := golangType.Elem()
elemJavascriptType, err := javascriptType(elemGolangType) elemJavascriptType, err := javascriptType(elemGolangType)
if err != nil { if err != nil {
return nil, errWithTrace(err.Error(), traceSlice) return nil, err
} }
elemProps, err := safeToProperty(elemGolangType, traceSet, traceSlice)
// detect cycles. Fail if a cycle is detected
// TODO: Add references here for cycles
_, ok := traceSet[elemGolangType]
if ok {
fmt.Println("[DEBUG] traceSet: ", traceSet)
return nil, errWithTrace("cycle detected", traceSlice)
}
// add current child field to history
traceSet[elemGolangType] = struct{}{}
elemProps, err := toProperity(elemGolangType, traceSet, traceSlice)
if err != nil { if err != nil {
return nil, errWithTrace(err.Error(), traceSlice) return nil, err
} }
items = &Item{ items = &Item{
// TODO: Add a test for slice of object // TODO: Add a test for slice of object
Type: elemJavascriptType, Type: elemJavascriptType,
Properities: elemProps.Properities, Properities: elemProps.Properities,
} }
} }
// var additionalProperties *Property // case map
// if golangType.Kind() == reflect.Map { var additionalProperties *Property
// additionalProperties = if golangType.Kind() == reflect.Map {
// } if golangType.Key().Kind() != reflect.String {
return nil, errWithTrace("only string keyed maps allowed", traceSlice)
}
additionalProperties, err = safeToProperty(golangType.Elem(), traceSet, traceSlice)
if err != nil {
return nil, err
}
}
// case struct
properities := map[string]*Property{} properities := map[string]*Property{}
if golangType.Kind() == reflect.Struct { if golangType.Kind() == reflect.Struct {
for i := 0; i < golangType.NumField(); i++ { children := []reflect.StructField{}
child := golangType.Field(i) children = addStructFields(children, golangType)
for _, child := range children {
// compute child properties // compute child properties
childJsonTag := child.Tag.Get("json") childJsonTag := child.Tag.Get("json")
childName := strings.Split(childJsonTag, ",")[0] childName := strings.Split(childJsonTag, ",")[0]
@ -157,36 +213,19 @@ func toProperity(golangType reflect.Type, traceSet map[reflect.Type]struct{}, tr
continue continue
} }
// detect cycles. Fail if a cycle is detected
// TODO: Add references here for cycles
_, ok := traceSet[child.Type]
if ok {
fmt.Println("[DEBUG] traceSet: ", traceSet)
return nil, errWithTrace("cycle detected", traceSlice)
}
// add current child field to history
traceSet[child.Type] = struct{}{}
// recursively compute properties for this child field // recursively compute properties for this child field
fieldProps, err := toProperity(child.Type, traceSet, traceSlice) fieldProps, err := safeToProperty(child.Type, traceSet, traceSlice)
if err != nil { if err != nil {
return nil, errWithTrace(err.Error(), traceSlice) return nil, errWithTrace(err.Error(), traceSlice)
} }
// traversal complete, delete child from history
delete(traceSet, child.Type)
properities[childName] = fieldProps properities[childName] = fieldProps
} }
} }
traceSlice = traceSlice[:len(traceSlice)-1]
return &Property{ return &Property{
Type: rootJavascriptType, Type: rootJavascriptType,
Items: items, Items: items,
Properities: properities, Properities: properities,
AdditionalProperities: additionalProperties,
}, nil }, nil
} }

View File

@ -172,9 +172,7 @@ func TestSliceOfObjectsSchema(t *testing.T) {
} }
type Story struct { type Story struct {
Hero Person `json:"hero"` Plot Plot `json:"plot"`
Villian Person `json:"villian"`
Plot Plot `json:"plot"`
} }
elem := Story{} elem := Story{}
@ -189,17 +187,6 @@ func TestSliceOfObjectsSchema(t *testing.T) {
`{ `{
"type": "object", "type": "object",
"properties": { "properties": {
"hero": {
"type": "object",
"properties": {
"age": {
"type": "number"
},
"name": {
"type": "string"
}
}
},
"plot": { "plot": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -218,15 +205,57 @@ func TestSliceOfObjectsSchema(t *testing.T) {
} }
} }
} }
}, }
"villian": { }
}`
fmt.Println("[DEBUG] actual: ", string(jsonSchema))
fmt.Println("[DEBUG] expected: ", expected)
assert.Equal(t, expected, string(jsonSchema))
}
func TestMapOfObjectsSchema(t *testing.T) {
type Person struct {
Name string `json:"name"`
Age int `json:"age,omitempty"`
}
type Plot struct {
Events map[string]Person `json:"events"`
}
type Story struct {
Plot Plot `json:"plot"`
}
elem := Story{}
schema, err := NewSchema(reflect.TypeOf(elem))
assert.NoError(t, err)
jsonSchema, err := json.MarshalIndent(schema, " ", " ")
assert.NoError(t, err)
expected :=
`{
"type": "object",
"properties": {
"plot": {
"type": "object", "type": "object",
"properties": { "properties": {
"age": { "events": {
"type": "number" "type": "object",
}, "additionalProperties": {
"name": { "type": "object",
"type": "string" "properties": {
"age": {
"type": "number"
},
"name": {
"type": "string"
}
}
}
} }
} }
} }