From 5018059444f907221ecffb20f7f1ecb2d6eddf16 Mon Sep 17 00:00:00 2001 From: Pieter Noordhuis Date: Tue, 24 Oct 2023 13:12:36 +0200 Subject: [PATCH] Library to convert config.Value to Go struct (#904) ## Changes Now that we have a new YAML loader (see #828), we need code to turn this into our Go structs. ## Tests New unit tests pass. Confirmed that we can replace our existing loader/converter with this one and that existing unit tests for bundle loading still pass. --- libs/config/convert/error.go | 16 + libs/config/convert/struct_info.go | 87 +++++ libs/config/convert/struct_info_test.go | 89 +++++ libs/config/convert/to_typed.go | 224 ++++++++++++ libs/config/convert/to_typed_test.go | 430 ++++++++++++++++++++++++ libs/config/kind.go | 64 ++++ libs/config/value.go | 113 +++++-- 7 files changed, 998 insertions(+), 25 deletions(-) create mode 100644 libs/config/convert/error.go create mode 100644 libs/config/convert/struct_info.go create mode 100644 libs/config/convert/struct_info_test.go create mode 100644 libs/config/convert/to_typed.go create mode 100644 libs/config/convert/to_typed_test.go create mode 100644 libs/config/kind.go diff --git a/libs/config/convert/error.go b/libs/config/convert/error.go new file mode 100644 index 00000000..b55668d6 --- /dev/null +++ b/libs/config/convert/error.go @@ -0,0 +1,16 @@ +package convert + +import ( + "fmt" + + "github.com/databricks/cli/libs/config" +) + +type TypeError struct { + value config.Value + msg string +} + +func (e TypeError) Error() string { + return fmt.Sprintf("%s: %s", e.value.Location(), e.msg) +} diff --git a/libs/config/convert/struct_info.go b/libs/config/convert/struct_info.go new file mode 100644 index 00000000..367b9ecd --- /dev/null +++ b/libs/config/convert/struct_info.go @@ -0,0 +1,87 @@ +package convert + +import ( + "reflect" + "strings" + "sync" +) + +// structInfo holds the type information we need to efficiently +// convert data from a [config.Value] to a Go struct. +type structInfo struct { + // Fields maps the JSON-name of the field to the field's index for use with [FieldByIndex]. + Fields map[string][]int +} + +// structInfoCache caches type information. +var structInfoCache = make(map[reflect.Type]structInfo) + +// structInfoCacheLock guards concurrent access to structInfoCache. +var structInfoCacheLock sync.Mutex + +// getStructInfo returns the [structInfo] for the given type. +// It lazily populates a cache, so the first call for a given +// type is slower than subsequent calls for that same type. +func getStructInfo(typ reflect.Type) structInfo { + structInfoCacheLock.Lock() + defer structInfoCacheLock.Unlock() + + si, ok := structInfoCache[typ] + if !ok { + si = buildStructInfo(typ) + structInfoCache[typ] = si + } + + return si +} + +// buildStructInfo populates a new [structInfo] for the given type. +func buildStructInfo(typ reflect.Type) structInfo { + var out = structInfo{ + Fields: make(map[string][]int), + } + + // Queue holds the indexes of the structs to visit. + // It is initialized with a single empty slice to visit the top level struct. + var queue [][]int = [][]int{{}} + for i := 0; i < len(queue); i++ { + prefix := queue[i] + + // Traverse embedded anonymous types (if prefix is non-empty). + styp := typ + if len(prefix) > 0 { + styp = styp.FieldByIndex(prefix).Type + } + + // Dereference pointer type. + if styp.Kind() == reflect.Pointer { + styp = styp.Elem() + } + + nf := styp.NumField() + for j := 0; j < nf; j++ { + sf := styp.Field(j) + + // Recurse into anonymous fields. + if sf.Anonymous { + queue = append(queue, append(prefix, sf.Index...)) + continue + } + + name, _, _ := strings.Cut(sf.Tag.Get("json"), ",") + if name == "" || name == "-" { + continue + } + + // Top level fields always take precedence. + // Therefore, if it is already set, we ignore it. + if _, ok := out.Fields[name]; ok { + continue + } + + out.Fields[name] = append(prefix, sf.Index...) + } + } + + return out +} diff --git a/libs/config/convert/struct_info_test.go b/libs/config/convert/struct_info_test.go new file mode 100644 index 00000000..3079958b --- /dev/null +++ b/libs/config/convert/struct_info_test.go @@ -0,0 +1,89 @@ +package convert + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestStructInfoPlain(t *testing.T) { + type Tmp struct { + Foo string `json:"foo"` + Bar string `json:"bar,omitempty"` + + // Baz must be skipped. + Baz string `json:""` + + // Qux must be skipped. + Qux string `json:"-"` + } + + si := getStructInfo(reflect.TypeOf(Tmp{})) + assert.Len(t, si.Fields, 2) + assert.Equal(t, []int{0}, si.Fields["foo"]) + assert.Equal(t, []int{1}, si.Fields["bar"]) +} + +func TestStructInfoAnonymousByValue(t *testing.T) { + type Bar struct { + Bar string `json:"bar"` + } + + type Foo struct { + Foo string `json:"foo"` + Bar + } + + type Tmp struct { + Foo + } + + si := getStructInfo(reflect.TypeOf(Tmp{})) + assert.Len(t, si.Fields, 2) + assert.Equal(t, []int{0, 0}, si.Fields["foo"]) + assert.Equal(t, []int{0, 1, 0}, si.Fields["bar"]) +} + +func TestStructInfoAnonymousByValuePrecedence(t *testing.T) { + type Bar struct { + Bar string `json:"bar"` + } + + type Foo struct { + Foo string `json:"foo"` + Bar + } + + type Tmp struct { + // "foo" comes from [Foo]. + Foo + // "bar" comes from [Bar] directly, not through [Foo]. + Bar + } + + si := getStructInfo(reflect.TypeOf(Tmp{})) + assert.Len(t, si.Fields, 2) + assert.Equal(t, []int{0, 0}, si.Fields["foo"]) + assert.Equal(t, []int{1, 0}, si.Fields["bar"]) +} + +func TestStructInfoAnonymousByPointer(t *testing.T) { + type Bar struct { + Bar string `json:"bar"` + } + + type Foo struct { + Foo string `json:"foo"` + *Bar + } + + type Tmp struct { + *Foo + } + + si := getStructInfo(reflect.TypeOf(Tmp{})) + assert.Len(t, si.Fields, 2) + assert.Equal(t, []int{0, 0}, si.Fields["foo"]) + assert.Equal(t, []int{0, 1, 0}, si.Fields["bar"]) +} diff --git a/libs/config/convert/to_typed.go b/libs/config/convert/to_typed.go new file mode 100644 index 00000000..9915d30a --- /dev/null +++ b/libs/config/convert/to_typed.go @@ -0,0 +1,224 @@ +package convert + +import ( + "fmt" + "reflect" + "strconv" + + "github.com/databricks/cli/libs/config" +) + +func ToTyped(dst any, src config.Value) error { + dstv := reflect.ValueOf(dst) + + // Dereference pointer if necessary + for dstv.Kind() == reflect.Pointer { + if dstv.IsNil() { + dstv.Set(reflect.New(dstv.Type().Elem())) + } + dstv = dstv.Elem() + } + + // Verify that vv is settable. + if !dstv.CanSet() { + panic("cannot set destination value") + } + + switch dstv.Kind() { + case reflect.Struct: + return toTypedStruct(dstv, src) + case reflect.Map: + return toTypedMap(dstv, src) + case reflect.Slice: + return toTypedSlice(dstv, src) + case reflect.String: + return toTypedString(dstv, src) + case reflect.Bool: + return toTypedBool(dstv, src) + case reflect.Int, reflect.Int32, reflect.Int64: + return toTypedInt(dstv, src) + case reflect.Float32, reflect.Float64: + return toTypedFloat(dstv, src) + } + + return fmt.Errorf("unsupported type: %s", dstv.Kind()) +} + +func toTypedStruct(dst reflect.Value, src config.Value) error { + switch src.Kind() { + case config.KindMap: + info := getStructInfo(dst.Type()) + for k, v := range src.MustMap() { + index, ok := info.Fields[k] + if !ok { + // Ignore unknown fields. + // A warning will be printed later. See PR #904. + continue + } + + // Create intermediate structs embedded as pointer types. + // Code inspired by [reflect.FieldByIndex] implementation. + f := dst + for i, x := range index { + if i > 0 { + if f.Kind() == reflect.Pointer { + if f.IsNil() { + f.Set(reflect.New(f.Type().Elem())) + } + f = f.Elem() + } + } + f = f.Field(x) + } + + err := ToTyped(f.Addr().Interface(), v) + if err != nil { + return err + } + } + + return nil + case config.KindNil: + dst.SetZero() + return nil + } + + return TypeError{ + value: src, + msg: fmt.Sprintf("expected a map, found a %s", src.Kind()), + } +} + +func toTypedMap(dst reflect.Value, src config.Value) error { + switch src.Kind() { + case config.KindMap: + m := src.MustMap() + + // Always overwrite. + dst.Set(reflect.MakeMapWithSize(dst.Type(), len(m))) + for k, v := range m { + kv := reflect.ValueOf(k) + vv := reflect.New(dst.Type().Elem()) + err := ToTyped(vv.Interface(), v) + if err != nil { + return err + } + dst.SetMapIndex(kv, vv.Elem()) + } + return nil + case config.KindNil: + dst.SetZero() + return nil + } + + return TypeError{ + value: src, + msg: fmt.Sprintf("expected a map, found a %s", src.Kind()), + } +} + +func toTypedSlice(dst reflect.Value, src config.Value) error { + switch src.Kind() { + case config.KindSequence: + seq := src.MustSequence() + + // Always overwrite. + dst.Set(reflect.MakeSlice(dst.Type(), len(seq), len(seq))) + for i := range seq { + err := ToTyped(dst.Index(i).Addr().Interface(), seq[i]) + if err != nil { + return err + } + } + return nil + case config.KindNil: + dst.SetZero() + return nil + } + + return TypeError{ + value: src, + msg: fmt.Sprintf("expected a sequence, found a %s", src.Kind()), + } +} + +func toTypedString(dst reflect.Value, src config.Value) error { + switch src.Kind() { + case config.KindString: + dst.SetString(src.MustString()) + return nil + case config.KindBool: + dst.SetString(strconv.FormatBool(src.MustBool())) + return nil + case config.KindInt: + dst.SetString(strconv.FormatInt(src.MustInt(), 10)) + return nil + case config.KindFloat: + dst.SetString(strconv.FormatFloat(src.MustFloat(), 'f', -1, 64)) + return nil + } + + return TypeError{ + value: src, + msg: fmt.Sprintf("expected a string, found a %s", src.Kind()), + } +} + +func toTypedBool(dst reflect.Value, src config.Value) error { + switch src.Kind() { + case config.KindBool: + dst.SetBool(src.MustBool()) + return nil + case config.KindString: + // See https://github.com/go-yaml/yaml/blob/f6f7691b1fdeb513f56608cd2c32c51f8194bf51/decode.go#L684-L693. + switch src.MustString() { + case "y", "Y", "yes", "Yes", "YES", "on", "On", "ON": + dst.SetBool(true) + return nil + case "n", "N", "no", "No", "NO", "off", "Off", "OFF": + dst.SetBool(false) + return nil + } + } + + return TypeError{ + value: src, + msg: fmt.Sprintf("expected a boolean, found a %s", src.Kind()), + } +} + +func toTypedInt(dst reflect.Value, src config.Value) error { + switch src.Kind() { + case config.KindInt: + dst.SetInt(src.MustInt()) + return nil + case config.KindString: + if i64, err := strconv.ParseInt(src.MustString(), 10, 64); err == nil { + dst.SetInt(i64) + return nil + } + } + + return TypeError{ + value: src, + msg: fmt.Sprintf("expected an int, found a %s", src.Kind()), + } +} + +func toTypedFloat(dst reflect.Value, src config.Value) error { + switch src.Kind() { + case config.KindFloat: + dst.SetFloat(src.MustFloat()) + return nil + case config.KindString: + if f64, err := strconv.ParseFloat(src.MustString(), 64); err == nil { + dst.SetFloat(f64) + return nil + } + } + + return TypeError{ + value: src, + msg: fmt.Sprintf("expected a float, found a %s", src.Kind()), + } +} diff --git a/libs/config/convert/to_typed_test.go b/libs/config/convert/to_typed_test.go new file mode 100644 index 00000000..26e17dcc --- /dev/null +++ b/libs/config/convert/to_typed_test.go @@ -0,0 +1,430 @@ +package convert + +import ( + "testing" + + "github.com/databricks/cli/libs/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestToTypedStruct(t *testing.T) { + type Tmp struct { + Foo string `json:"foo"` + Bar string `json:"bar,omitempty"` + + // Baz must be skipped. + Baz string `json:""` + + // Qux must be skipped. + Qux string `json:"-"` + } + + var out Tmp + v := config.V(map[string]config.Value{ + "foo": config.V("bar"), + "bar": config.V("baz"), + }) + + err := ToTyped(&out, v) + require.NoError(t, err) + assert.Equal(t, "bar", out.Foo) + assert.Equal(t, "baz", out.Bar) +} + +func TestToTypedStructOverwrite(t *testing.T) { + type Tmp struct { + Foo string `json:"foo"` + Bar string `json:"bar,omitempty"` + + // Baz must be skipped. + Baz string `json:""` + + // Qux must be skipped. + Qux string `json:"-"` + } + + var out = Tmp{ + Foo: "baz", + Bar: "qux", + } + v := config.V(map[string]config.Value{ + "foo": config.V("bar"), + "bar": config.V("baz"), + }) + + err := ToTyped(&out, v) + require.NoError(t, err) + assert.Equal(t, "bar", out.Foo) + assert.Equal(t, "baz", out.Bar) +} + +func TestToTypedStructAnonymousByValue(t *testing.T) { + type Bar struct { + Bar string `json:"bar"` + } + + type Foo struct { + Foo string `json:"foo"` + Bar + } + + type Tmp struct { + Foo + } + + var out Tmp + v := config.V(map[string]config.Value{ + "foo": config.V("bar"), + "bar": config.V("baz"), + }) + + err := ToTyped(&out, v) + require.NoError(t, err) + assert.Equal(t, "bar", out.Foo.Foo) + assert.Equal(t, "baz", out.Foo.Bar.Bar) +} + +func TestToTypedStructAnonymousByPointer(t *testing.T) { + type Bar struct { + Bar string `json:"bar"` + } + + type Foo struct { + Foo string `json:"foo"` + *Bar + } + + type Tmp struct { + *Foo + } + + var out Tmp + v := config.V(map[string]config.Value{ + "foo": config.V("bar"), + "bar": config.V("baz"), + }) + + err := ToTyped(&out, v) + require.NoError(t, err) + assert.Equal(t, "bar", out.Foo.Foo) + assert.Equal(t, "baz", out.Foo.Bar.Bar) +} + +func TestToTypedStructNil(t *testing.T) { + type Tmp struct { + Foo string `json:"foo"` + } + + var out = Tmp{} + err := ToTyped(&out, config.NilValue) + require.NoError(t, err) + assert.Equal(t, Tmp{}, out) +} + +func TestToTypedStructNilOverwrite(t *testing.T) { + type Tmp struct { + Foo string `json:"foo"` + } + + var out = Tmp{"bar"} + err := ToTyped(&out, config.NilValue) + require.NoError(t, err) + assert.Equal(t, Tmp{}, out) +} + +func TestToTypedMap(t *testing.T) { + var out = map[string]string{} + + v := config.V(map[string]config.Value{ + "key": config.V("value"), + }) + + err := ToTyped(&out, v) + require.NoError(t, err) + assert.Len(t, out, 1) + assert.Equal(t, "value", out["key"]) +} + +func TestToTypedMapOverwrite(t *testing.T) { + var out = map[string]string{ + "foo": "bar", + } + + v := config.V(map[string]config.Value{ + "bar": config.V("qux"), + }) + + err := ToTyped(&out, v) + require.NoError(t, err) + assert.Len(t, out, 1) + assert.Equal(t, "qux", out["bar"]) +} + +func TestToTypedMapWithPointerElement(t *testing.T) { + var out map[string]*string + + v := config.V(map[string]config.Value{ + "key": config.V("value"), + }) + + err := ToTyped(&out, v) + require.NoError(t, err) + assert.Len(t, out, 1) + assert.Equal(t, "value", *out["key"]) +} + +func TestToTypedMapNil(t *testing.T) { + var out = map[string]string{} + err := ToTyped(&out, config.NilValue) + require.NoError(t, err) + assert.Nil(t, out) +} + +func TestToTypedMapNilOverwrite(t *testing.T) { + var out = map[string]string{ + "foo": "bar", + } + err := ToTyped(&out, config.NilValue) + require.NoError(t, err) + assert.Nil(t, out) +} + +func TestToTypedSlice(t *testing.T) { + var out []string + + v := config.V([]config.Value{ + config.V("foo"), + config.V("bar"), + }) + + err := ToTyped(&out, v) + require.NoError(t, err) + assert.Len(t, out, 2) + assert.Equal(t, "foo", out[0]) + assert.Equal(t, "bar", out[1]) +} + +func TestToTypedSliceOverwrite(t *testing.T) { + var out = []string{"qux"} + + v := config.V([]config.Value{ + config.V("foo"), + config.V("bar"), + }) + + err := ToTyped(&out, v) + require.NoError(t, err) + assert.Len(t, out, 2) + assert.Equal(t, "foo", out[0]) + assert.Equal(t, "bar", out[1]) +} + +func TestToTypedSliceWithPointerElement(t *testing.T) { + var out []*string + + v := config.V([]config.Value{ + config.V("foo"), + config.V("bar"), + }) + + err := ToTyped(&out, v) + require.NoError(t, err) + assert.Len(t, out, 2) + assert.Equal(t, "foo", *out[0]) + assert.Equal(t, "bar", *out[1]) +} + +func TestToTypedSliceNil(t *testing.T) { + var out []string + err := ToTyped(&out, config.NilValue) + require.NoError(t, err) + assert.Nil(t, out) +} + +func TestToTypedSliceNilOverwrite(t *testing.T) { + var out = []string{"foo"} + err := ToTyped(&out, config.NilValue) + require.NoError(t, err) + assert.Nil(t, out) +} + +func TestToTypedString(t *testing.T) { + var out string + err := ToTyped(&out, config.V("foo")) + require.NoError(t, err) + assert.Equal(t, "foo", out) +} + +func TestToTypedStringOverwrite(t *testing.T) { + var out string = "bar" + err := ToTyped(&out, config.V("foo")) + require.NoError(t, err) + assert.Equal(t, "foo", out) +} + +func TestToTypedStringFromBool(t *testing.T) { + var out string + err := ToTyped(&out, config.V(true)) + require.NoError(t, err) + assert.Equal(t, "true", out) +} + +func TestToTypedStringFromInt(t *testing.T) { + var out string + err := ToTyped(&out, config.V(123)) + require.NoError(t, err) + assert.Equal(t, "123", out) +} + +func TestToTypedStringFromFloat(t *testing.T) { + var out string + err := ToTyped(&out, config.V(1.2)) + require.NoError(t, err) + assert.Equal(t, "1.2", out) +} + +func TestToTypedBool(t *testing.T) { + var out bool + err := ToTyped(&out, config.V(true)) + require.NoError(t, err) + assert.Equal(t, true, out) +} + +func TestToTypedBoolOverwrite(t *testing.T) { + var out bool = true + err := ToTyped(&out, config.V(false)) + require.NoError(t, err) + assert.Equal(t, false, out) +} + +func TestToTypedBoolFromString(t *testing.T) { + var out bool + + // True-ish + for _, v := range []string{"y", "yes", "on"} { + err := ToTyped(&out, config.V(v)) + require.NoError(t, err) + assert.Equal(t, true, out) + } + + // False-ish + for _, v := range []string{"n", "no", "off"} { + err := ToTyped(&out, config.V(v)) + require.NoError(t, err) + assert.Equal(t, false, out) + } + + // Other + err := ToTyped(&out, config.V("${var.foo}")) + require.Error(t, err) +} + +func TestToTypedInt(t *testing.T) { + var out int + err := ToTyped(&out, config.V(1234)) + require.NoError(t, err) + assert.Equal(t, int(1234), out) +} + +func TestToTypedInt32(t *testing.T) { + var out32 int32 + err := ToTyped(&out32, config.V(1235)) + require.NoError(t, err) + assert.Equal(t, int32(1235), out32) +} + +func TestToTypedInt64(t *testing.T) { + var out64 int64 + err := ToTyped(&out64, config.V(1236)) + require.NoError(t, err) + assert.Equal(t, int64(1236), out64) +} + +func TestToTypedIntOverwrite(t *testing.T) { + var out int = 123 + err := ToTyped(&out, config.V(1234)) + require.NoError(t, err) + assert.Equal(t, int(1234), out) +} + +func TestToTypedInt32Overwrite(t *testing.T) { + var out32 int32 = 123 + err := ToTyped(&out32, config.V(1234)) + require.NoError(t, err) + assert.Equal(t, int32(1234), out32) +} + +func TestToTypedInt64Overwrite(t *testing.T) { + var out64 int64 = 123 + err := ToTyped(&out64, config.V(1234)) + require.NoError(t, err) + assert.Equal(t, int64(1234), out64) +} + +func TestToTypedIntFromStringError(t *testing.T) { + var out int + err := ToTyped(&out, config.V("abc")) + require.Error(t, err) +} + +func TestToTypedIntFromStringInt(t *testing.T) { + var out int + err := ToTyped(&out, config.V("123")) + require.NoError(t, err) + assert.Equal(t, int(123), out) +} + +func TestToTypedFloat32(t *testing.T) { + var out float32 + err := ToTyped(&out, config.V(float32(1.0))) + require.NoError(t, err) + assert.Equal(t, float32(1.0), out) +} + +func TestToTypedFloat64(t *testing.T) { + var out float64 + err := ToTyped(&out, config.V(float64(1.0))) + require.NoError(t, err) + assert.Equal(t, float64(1.0), out) +} + +func TestToTypedFloat32Overwrite(t *testing.T) { + var out float32 = 1.0 + err := ToTyped(&out, config.V(float32(2.0))) + require.NoError(t, err) + assert.Equal(t, float32(2.0), out) +} + +func TestToTypedFloat64Overwrite(t *testing.T) { + var out float64 = 1.0 + err := ToTyped(&out, config.V(float64(2.0))) + require.NoError(t, err) + assert.Equal(t, float64(2.0), out) +} + +func TestToTypedFloat32FromStringError(t *testing.T) { + var out float32 + err := ToTyped(&out, config.V("abc")) + require.Error(t, err) +} + +func TestToTypedFloat64FromStringError(t *testing.T) { + var out float64 + err := ToTyped(&out, config.V("abc")) + require.Error(t, err) +} + +func TestToTypedFloat32FromString(t *testing.T) { + var out float32 + err := ToTyped(&out, config.V("1.2")) + require.NoError(t, err) + assert.Equal(t, float32(1.2), out) +} + +func TestToTypedFloat64FromString(t *testing.T) { + var out float64 + err := ToTyped(&out, config.V("1.2")) + require.NoError(t, err) + assert.Equal(t, float64(1.2), out) +} diff --git a/libs/config/kind.go b/libs/config/kind.go new file mode 100644 index 00000000..5ed1a665 --- /dev/null +++ b/libs/config/kind.go @@ -0,0 +1,64 @@ +package config + +import "time" + +type Kind int + +const ( + // Invalid is the zero value of Kind. + KindInvalid Kind = iota + KindMap + KindSequence + KindNil + KindString + KindBool + KindInt + KindFloat + KindTime +) + +func kindOf(v any) Kind { + switch v.(type) { + case map[string]Value: + return KindMap + case []Value: + return KindSequence + case nil: + return KindNil + case string: + return KindString + case bool: + return KindBool + case int, int32, int64: + return KindInt + case float32, float64: + return KindFloat + case time.Time: + return KindTime + default: + panic("not handled") + } +} + +func (k Kind) String() string { + switch k { + case KindMap: + return "map" + case KindSequence: + return "sequence" + case KindNil: + return "nil" + case KindString: + return "string" + case KindBool: + return "bool" + case KindInt: + return "int" + case KindFloat: + return "float" + case KindTime: + return "time" + default: + return "invalid" + } +} diff --git a/libs/config/value.go b/libs/config/value.go index 994aec38..c77f8147 100644 --- a/libs/config/value.go +++ b/libs/config/value.go @@ -1,9 +1,14 @@ package config -import "time" +import ( + "fmt" + "time" +) type Value struct { v any + + k Kind l Location // Whether or not this value is an anchor. @@ -12,12 +17,23 @@ type Value struct { } // NilValue is equal to the zero-value of Value. -var NilValue = Value{} +var NilValue = Value{ + k: KindNil, +} + +// V constructs a new Value with the given value. +func V(v any) Value { + return Value{ + v: v, + k: kindOf(v), + } +} // NewValue constructs a new Value with the given value and location. func NewValue(v any, loc Location) Value { return Value{ v: v, + k: kindOf(v), l: loc, } } @@ -27,45 +43,47 @@ func (v Value) AsMap() (map[string]Value, bool) { return m, ok } +func (v Value) Kind() Kind { + return v.k +} + func (v Value) Location() Location { return v.l } func (v Value) AsAny() any { - switch vv := v.v.(type) { - case map[string]Value: - m := make(map[string]any) + switch v.k { + case KindInvalid: + panic("invoked AsAny on invalid value") + case KindMap: + vv := v.v.(map[string]Value) + m := make(map[string]any, len(vv)) for k, v := range vv { m[k] = v.AsAny() } return m - case []Value: + case KindSequence: + vv := v.v.([]Value) a := make([]any, len(vv)) for i, v := range vv { a[i] = v.AsAny() } return a - case string: - return vv - case bool: - return vv - case int: - return vv - case int32: - return vv - case int64: - return vv - case float32: - return vv - case float64: - return vv - case time.Time: - return vv - case nil: - return nil + case KindNil: + return v.v + case KindString: + return v.v + case KindBool: + return v.v + case KindInt: + return v.v + case KindFloat: + return v.v + case KindTime: + return v.v default: // Panic because we only want to deal with known types. - panic("not handled") + panic(fmt.Sprintf("invalid kind: %d", v.k)) } } @@ -99,6 +117,7 @@ func (v Value) Index(i int) Value { func (v Value) MarkAnchor() Value { return Value{ v: v.v, + k: v.k, l: v.l, anchor: true, @@ -108,3 +127,47 @@ func (v Value) MarkAnchor() Value { func (v Value) IsAnchor() bool { return v.anchor } + +func (v Value) MustMap() map[string]Value { + return v.v.(map[string]Value) +} + +func (v Value) MustSequence() []Value { + return v.v.([]Value) +} + +func (v Value) MustString() string { + return v.v.(string) +} + +func (v Value) MustBool() bool { + return v.v.(bool) +} + +func (v Value) MustInt() int64 { + switch vv := v.v.(type) { + case int: + return int64(vv) + case int32: + return int64(vv) + case int64: + return int64(vv) + default: + panic("not an int") + } +} + +func (v Value) MustFloat() float64 { + switch vv := v.v.(type) { + case float32: + return float64(vv) + case float64: + return float64(vv) + default: + panic("not a float") + } +} + +func (v Value) MustTime() time.Time { + return v.v.(time.Time) +}