diff --git a/libs/dyn/dynassert/assert.go b/libs/dyn/dynassert/assert.go index dc6676ca..f667b08c 100644 --- a/libs/dyn/dynassert/assert.go +++ b/libs/dyn/dynassert/assert.go @@ -111,3 +111,7 @@ func PanicsWithError(t assert.TestingT, errString string, f func(), msgAndArgs . func NotPanics(t assert.TestingT, f func(), msgAndArgs ...interface{}) bool { return assert.NotPanics(t, f, msgAndArgs...) } + +func JSONEq(t assert.TestingT, expected string, actual string, msgAndArgs ...interface{}) bool { + return assert.JSONEq(t, expected, actual, msgAndArgs...) +} diff --git a/libs/dyn/jsonsaver/encoder.go b/libs/dyn/jsonsaver/encoder.go new file mode 100644 index 00000000..66997e96 --- /dev/null +++ b/libs/dyn/jsonsaver/encoder.go @@ -0,0 +1,39 @@ +package jsonsaver + +import ( + "bytes" + "encoding/json" +) + +// The encoder type encapsulates a [json.Encoder] and its target buffer. +// Escaping of HTML characters in the output is disabled. +type encoder struct { + *json.Encoder + *bytes.Buffer +} + +func newEncoder() encoder { + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + + // By default, json.Encoder escapes HTML characters, converting symbols like '<' to '\u003c'. + // This behavior helps prevent XSS attacks when JSON is embedded within HTML. + // However, we disable this feature since we're not dealing with HTML context. + // Keeping the escapes enabled would result in unnecessary differences when processing JSON payloads + // that already contain escaped characters. + enc.SetEscapeHTML(false) + return encoder{enc, &buf} +} + +func marshalNoEscape(v any) ([]byte, error) { + enc := newEncoder() + err := enc.Encode(v) + return enc.Bytes(), err +} + +func marshalIndentNoEscape(v any, prefix, indent string) ([]byte, error) { + enc := newEncoder() + enc.SetIndent(prefix, indent) + err := enc.Encode(v) + return enc.Bytes(), err +} diff --git a/libs/dyn/jsonsaver/encoder_test.go b/libs/dyn/jsonsaver/encoder_test.go new file mode 100644 index 00000000..d1b7d017 --- /dev/null +++ b/libs/dyn/jsonsaver/encoder_test.go @@ -0,0 +1,41 @@ +package jsonsaver + +import ( + "testing" + + assert "github.com/databricks/cli/libs/dyn/dynassert" +) + +func TestEncoder_MarshalNoEscape(t *testing.T) { + out, err := marshalNoEscape("1 < 2") + if !assert.NoError(t, err) { + return + } + + // Confirm the output. + assert.JSONEq(t, `"1 < 2"`, string(out)) + + // Confirm that HTML escaping is disabled. + assert.NotContains(t, string(out), "\\u003c") + + // Confirm that the encoder writes a trailing newline. + assert.Contains(t, string(out), "\n") +} + +func TestEncoder_MarshalIndentNoEscape(t *testing.T) { + out, err := marshalIndentNoEscape([]string{"1 < 2", "2 < 3"}, "", " ") + if !assert.NoError(t, err) { + return + } + + // Confirm the output. + assert.JSONEq(t, `["1 < 2", "2 < 3"]`, string(out)) + + // Confirm that HTML escaping is disabled. + assert.NotContains(t, string(out), "\\u003c") + + // Confirm that the encoder performs indenting and writes a trailing newline. + assert.Contains(t, string(out), "[\n") + assert.Contains(t, string(out), " \"1 < 2\",\n") + assert.Contains(t, string(out), "]\n") +} diff --git a/libs/dyn/jsonsaver/marshal.go b/libs/dyn/jsonsaver/marshal.go new file mode 100644 index 00000000..a78a68f2 --- /dev/null +++ b/libs/dyn/jsonsaver/marshal.go @@ -0,0 +1,89 @@ +package jsonsaver + +import ( + "bytes" + "fmt" + + "github.com/databricks/cli/libs/dyn" +) + +// Marshal is a version of [json.Marshal] for [dyn.Value]. +// +// Objects in the output retain the order of keys as they appear in the underlying [dyn.Value]. +// The output does not escape HTML characters in strings. +func Marshal(v dyn.Value) ([]byte, error) { + return marshalNoEscape(wrap{v}) +} + +// MarshalIndent is a version of [json.MarshalIndent] for [dyn.Value]. +// +// Objects in the output retain the order of keys as they appear in the underlying [dyn.Value]. +// The output does not escape HTML characters in strings. +func MarshalIndent(v dyn.Value, prefix, indent string) ([]byte, error) { + return marshalIndentNoEscape(wrap{v}, prefix, indent) +} + +// Wrapper type for [dyn.Value] to expose the [json.Marshaler] interface. +type wrap struct { + v dyn.Value +} + +// MarshalJSON implements the [json.Marshaler] interface for the [dyn.Value] wrapper type. +func (w wrap) MarshalJSON() ([]byte, error) { + var buf bytes.Buffer + if err := marshalValue(&buf, w.v); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// marshalValue recursively writes JSON for a [dyn.Value] to the buffer. +func marshalValue(buf *bytes.Buffer, v dyn.Value) error { + switch v.Kind() { + case dyn.KindString, dyn.KindBool, dyn.KindInt, dyn.KindFloat, dyn.KindTime, dyn.KindNil: + out, err := marshalNoEscape(v.AsAny()) + if err != nil { + return err + } + + // The encoder writes a trailing newline, so we need to remove it + // to avoid adding extra newlines when embedding this JSON. + out = out[:len(out)-1] + buf.Write(out) + case dyn.KindMap: + buf.WriteByte('{') + for i, pair := range v.MustMap().Pairs() { + if i > 0 { + buf.WriteByte(',') + } + // Require keys to be strings. + if pair.Key.Kind() != dyn.KindString { + return fmt.Errorf("map key must be a string, got %s", pair.Key.Kind()) + } + // Marshal the key + if err := marshalValue(buf, pair.Key); err != nil { + return err + } + buf.WriteByte(':') + // Marshal the value + if err := marshalValue(buf, pair.Value); err != nil { + return err + } + } + buf.WriteByte('}') + case dyn.KindSequence: + buf.WriteByte('[') + for i, item := range v.MustSequence() { + if i > 0 { + buf.WriteByte(',') + } + if err := marshalValue(buf, item); err != nil { + return err + } + } + buf.WriteByte(']') + default: + return fmt.Errorf("unsupported kind: %d", v.Kind()) + } + return nil +} diff --git a/libs/dyn/jsonsaver/marshal_test.go b/libs/dyn/jsonsaver/marshal_test.go new file mode 100644 index 00000000..0b6a3428 --- /dev/null +++ b/libs/dyn/jsonsaver/marshal_test.go @@ -0,0 +1,100 @@ +package jsonsaver + +import ( + "testing" + + "github.com/databricks/cli/libs/dyn" + assert "github.com/databricks/cli/libs/dyn/dynassert" +) + +func TestMarshal_String(t *testing.T) { + b, err := Marshal(dyn.V("string")) + if assert.NoError(t, err) { + assert.JSONEq(t, `"string"`, string(b)) + } +} + +func TestMarshal_Bool(t *testing.T) { + b, err := Marshal(dyn.V(true)) + if assert.NoError(t, err) { + assert.JSONEq(t, `true`, string(b)) + } +} + +func TestMarshal_Int(t *testing.T) { + b, err := Marshal(dyn.V(42)) + if assert.NoError(t, err) { + assert.JSONEq(t, `42`, string(b)) + } +} + +func TestMarshal_Float(t *testing.T) { + b, err := Marshal(dyn.V(42.1)) + if assert.NoError(t, err) { + assert.JSONEq(t, `42.1`, string(b)) + } +} + +func TestMarshal_Time(t *testing.T) { + b, err := Marshal(dyn.V(dyn.MustTime("2021-01-01T00:00:00Z"))) + if assert.NoError(t, err) { + assert.JSONEq(t, `"2021-01-01T00:00:00Z"`, string(b)) + } +} + +func TestMarshal_Map(t *testing.T) { + m := dyn.NewMapping() + m.Set(dyn.V("key1"), dyn.V("value1")) + m.Set(dyn.V("key2"), dyn.V("value2")) + + b, err := Marshal(dyn.V(m)) + if assert.NoError(t, err) { + assert.JSONEq(t, `{"key1":"value1","key2":"value2"}`, string(b)) + } +} + +func TestMarshal_Sequence(t *testing.T) { + var s []dyn.Value + s = append(s, dyn.V("value1")) + s = append(s, dyn.V("value2")) + + b, err := Marshal(dyn.V(s)) + if assert.NoError(t, err) { + assert.JSONEq(t, `["value1","value2"]`, string(b)) + } +} + +func TestMarshal_Complex(t *testing.T) { + map1 := dyn.NewMapping() + map1.Set(dyn.V("str1"), dyn.V("value1")) + map1.Set(dyn.V("str2"), dyn.V("value2")) + + seq1 := []dyn.Value{} + seq1 = append(seq1, dyn.V("value1")) + seq1 = append(seq1, dyn.V("value2")) + + root := dyn.NewMapping() + root.Set(dyn.V("map1"), dyn.V(map1)) + root.Set(dyn.V("seq1"), dyn.V(seq1)) + + // Marshal without indent. + b, err := Marshal(dyn.V(root)) + if assert.NoError(t, err) { + assert.Equal(t, `{"map1":{"str1":"value1","str2":"value2"},"seq1":["value1","value2"]}`+"\n", string(b)) + } + + // Marshal with indent. + b, err = MarshalIndent(dyn.V(root), "", " ") + if assert.NoError(t, err) { + assert.Equal(t, `{ + "map1": { + "str1": "value1", + "str2": "value2" + }, + "seq1": [ + "value1", + "value2" + ] +}`+"\n", string(b)) + } +}