From 58ab2f2cfe8f795dec8ec589924f019b729265f0 Mon Sep 17 00:00:00 2001 From: Andrew Nester Date: Fri, 13 Sep 2024 13:27:53 +0200 Subject: [PATCH] Added JSON input validation for CLI commands --- libs/dyn/convert/to_typed.go | 19 +++++- libs/dyn/jsonloader/json.go | 21 +++++++ libs/dyn/jsonloader/json_test.go | 53 ++++++++++++++++ libs/dyn/jsonloader/loader.go | 99 ++++++++++++++++++++++++++++++ libs/flags/json_flag.go | 26 +++++++- libs/flags/json_flag_test.go | 101 ++++++++++++++++++++++++++++++- 6 files changed, 313 insertions(+), 6 deletions(-) create mode 100644 libs/dyn/jsonloader/json.go create mode 100644 libs/dyn/jsonloader/json_test.go create mode 100644 libs/dyn/jsonloader/loader.go diff --git a/libs/dyn/convert/to_typed.go b/libs/dyn/convert/to_typed.go index 839d0111a..5d970dd02 100644 --- a/libs/dyn/convert/to_typed.go +++ b/libs/dyn/convert/to_typed.go @@ -221,10 +221,10 @@ func toTypedBool(dst reflect.Value, src dyn.Value) error { case dyn.KindString: // See https://github.com/go-yaml/yaml/blob/f6f7691b1fdeb513f56608cd2c32c51f8194bf51/decode.go#L684-L693. switch src.MustString() { - case "y", "Y", "yes", "Yes", "YES", "on", "On", "ON": + case "y", "Y", "yes", "Yes", "YES", "on", "On", "ON", "true": dst.SetBool(true) return nil - case "n", "N", "no", "No", "NO", "off", "Off", "OFF": + case "n", "N", "no", "No", "NO", "off", "Off", "OFF", "false": dst.SetBool(false) return nil } @@ -246,6 +246,17 @@ func toTypedInt(dst reflect.Value, src dyn.Value) error { case dyn.KindInt: dst.SetInt(src.MustInt()) return nil + case dyn.KindFloat: + v := src.MustFloat() + if canConvertToInt(v) { + dst.SetInt(int64(src.MustFloat())) + return nil + } + + return TypeError{ + value: src, + msg: fmt.Sprintf("expected an int, found a %s", src.Kind()), + } case dyn.KindString: if i64, err := strconv.ParseInt(src.MustString(), 10, 64); err == nil { dst.SetInt(i64) @@ -264,6 +275,10 @@ func toTypedInt(dst reflect.Value, src dyn.Value) error { } } +func canConvertToInt(v float64) bool { + return v == float64(int(v)) +} + func toTypedFloat(dst reflect.Value, src dyn.Value) error { switch src.Kind() { case dyn.KindFloat: diff --git a/libs/dyn/jsonloader/json.go b/libs/dyn/jsonloader/json.go new file mode 100644 index 000000000..36d594fb9 --- /dev/null +++ b/libs/dyn/jsonloader/json.go @@ -0,0 +1,21 @@ +package jsonloader + +import ( + "encoding/json" + + "github.com/databricks/cli/libs/dyn" +) + +func LoadJSON(data []byte) (dyn.Value, error) { + var root map[string]interface{} + err := json.Unmarshal(data, &root) + if err != nil { + return dyn.InvalidValue, err + } + + loc := dyn.Location{ + Line: 1, + Column: 1, + } + return newLoader().load(&root, loc) +} diff --git a/libs/dyn/jsonloader/json_test.go b/libs/dyn/jsonloader/json_test.go new file mode 100644 index 000000000..f97739c6c --- /dev/null +++ b/libs/dyn/jsonloader/json_test.go @@ -0,0 +1,53 @@ +package jsonloader + +import ( + "testing" + + "github.com/databricks/cli/libs/dyn/convert" + "github.com/databricks/databricks-sdk-go/service/jobs" + "github.com/stretchr/testify/require" +) + +const jsonData = ` +{ + "job_id": 123, + "new_settings": { + "name": "xxx", + "email_notifications": { + "on_start": [], + "on_success": [], + "on_failure": [] + }, + "webhook_notifications": { + "on_start": [], + "on_failure": [] + }, + "notification_settings": { + "no_alert_for_skipped_runs": true, + "no_alert_for_canceled_runs": true + }, + "timeout_seconds": 0, + "max_concurrent_runs": 1, + "tasks": [ + { + "task_key": "xxx", + "email_notifications": {}, + "notification_settings": {}, + "timeout_seconds": 0, + "max_retries": 0, + "min_retry_interval_millis": 0, + "retry_on_timeout": "true" + } + ] + } +} +` + +func TestJsonLoader(t *testing.T) { + v, err := LoadJSON([]byte(jsonData)) + require.NoError(t, err) + + var r jobs.ResetJob + err = convert.ToTyped(&r, v) + require.NoError(t, err) +} diff --git a/libs/dyn/jsonloader/loader.go b/libs/dyn/jsonloader/loader.go new file mode 100644 index 000000000..6f82eb679 --- /dev/null +++ b/libs/dyn/jsonloader/loader.go @@ -0,0 +1,99 @@ +package jsonloader + +import ( + "fmt" + "reflect" + + "github.com/databricks/cli/libs/dyn" +) + +type loader struct { +} + +func newLoader() *loader { + return &loader{} +} + +func errorf(loc dyn.Location, format string, args ...interface{}) error { + return fmt.Errorf("json (%s): %s", loc, fmt.Sprintf(format, args...)) +} + +func (d *loader) load(node any, loc dyn.Location) (dyn.Value, error) { + var value dyn.Value + var err error + + if node == nil { + return dyn.NilValue, nil + } + + if reflect.TypeOf(node).Kind() == reflect.Ptr { + return d.load(reflect.ValueOf(node).Elem().Interface(), loc) + } + + switch reflect.TypeOf(node).Kind() { + case reflect.Map: + value, err = d.loadMapping(node.(map[string]interface{}), loc) + case reflect.Slice: + value, err = d.loadSequence(node.([]interface{}), loc) + case reflect.String, reflect.Bool, + reflect.Float64, reflect.Float32, + reflect.Int, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint32, reflect.Uint64: + value, err = d.loadScalar(node, loc) + + default: + return dyn.InvalidValue, errorf(loc, "unknown node kind: %v", reflect.TypeOf(node).Kind()) + } + + if err != nil { + return dyn.InvalidValue, err + } + + return value, nil +} + +func (d *loader) loadScalar(node any, loc dyn.Location) (dyn.Value, error) { + switch reflect.TypeOf(node).Kind() { + case reflect.String: + return dyn.NewValue(node.(string), []dyn.Location{loc}), nil + case reflect.Bool: + return dyn.NewValue(node.(bool), []dyn.Location{loc}), nil + case reflect.Float64, reflect.Float32: + return dyn.NewValue(node.(float64), []dyn.Location{loc}), nil + case reflect.Int, reflect.Int32, reflect.Int64: + return dyn.NewValue(node.(int64), []dyn.Location{loc}), nil + case reflect.Uint, reflect.Uint32, reflect.Uint64: + return dyn.NewValue(node.(uint64), []dyn.Location{loc}), nil + default: + return dyn.InvalidValue, errorf(loc, "unknown scalar type: %v", reflect.TypeOf(node).Kind()) + } +} + +func (d *loader) loadSequence(node []interface{}, loc dyn.Location) (dyn.Value, error) { + dst := make([]dyn.Value, len(node)) + for i, value := range node { + v, err := d.load(value, loc) + if err != nil { + return dyn.InvalidValue, err + } + dst[i] = v + } + return dyn.NewValue(dst, []dyn.Location{loc}), nil +} + +func (d *loader) loadMapping(node map[string]interface{}, loc dyn.Location) (dyn.Value, error) { + dst := make(map[string]dyn.Value) + index := 0 + for key, value := range node { + index += 1 + v, err := d.load(value, dyn.Location{ + Line: loc.Line + index, + Column: loc.Column, + }) + if err != nil { + return dyn.InvalidValue, err + } + dst[key] = v + } + return dyn.NewValue(dst, []dyn.Location{loc}), nil +} diff --git a/libs/flags/json_flag.go b/libs/flags/json_flag.go index 8dbc3b2d9..3330366ca 100644 --- a/libs/flags/json_flag.go +++ b/libs/flags/json_flag.go @@ -1,9 +1,11 @@ package flags import ( - "encoding/json" "fmt" "os" + + "github.com/databricks/cli/libs/dyn/convert" + "github.com/databricks/cli/libs/dyn/jsonloader" ) type JsonFlag struct { @@ -33,7 +35,27 @@ func (j *JsonFlag) Unmarshal(v any) error { if j.raw == nil { return nil } - return json.Unmarshal(j.raw, v) + + dv, err := jsonloader.LoadJSON(j.raw) + if err != nil { + return err + } + + err = convert.ToTyped(v, dv) + if err != nil { + return err + } + + _, diags := convert.Normalize(v, dv) + if len(diags) > 0 { + summary := "" + for _, diag := range diags { + summary += fmt.Sprintf("- %s\n", diag.Summary) + } + return fmt.Errorf("json input error:\n%v", summary) + } + + return nil } func (j *JsonFlag) Type() string { diff --git a/libs/flags/json_flag_test.go b/libs/flags/json_flag_test.go index 2a8170fe6..e5030351d 100644 --- a/libs/flags/json_flag_test.go +++ b/libs/flags/json_flag_test.go @@ -6,6 +6,7 @@ import ( "path" "testing" + "github.com/databricks/databricks-sdk-go/service/jobs" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -52,7 +53,7 @@ func TestJsonFlagFile(t *testing.T) { var request any var fpath string - var payload = []byte(`"hello world"`) + var payload = []byte(`{"hello": "world"}`) { f, err := os.Create(path.Join(t.TempDir(), "file")) @@ -68,5 +69,101 @@ func TestJsonFlagFile(t *testing.T) { err = body.Unmarshal(&request) require.NoError(t, err) - assert.Equal(t, "hello world", request) + assert.Equal(t, map[string]interface{}{"hello": "world"}, request) +} + +const jsonData = ` +{ + "job_id": 123, + "new_settings": { + "name": "new job", + "email_notifications": { + "on_start": [], + "on_success": [], + "on_failure": [] + }, + "notification_settings": { + "no_alert_for_skipped_runs": true, + "no_alert_for_canceled_runs": true + }, + "timeout_seconds": 0, + "max_concurrent_runs": 1, + "tasks": [ + { + "task_key": "new task", + "email_notifications": {}, + "notification_settings": {}, + "timeout_seconds": 0, + "max_retries": 0, + "min_retry_interval_millis": 0, + "retry_on_timeout": "true" + } + ] + } +} +` + +func TestJsonUnmarshalForRequest(t *testing.T) { + var body JsonFlag + + var r jobs.ResetJob + err := body.Set(jsonData) + require.NoError(t, err) + + err = body.Unmarshal(&r) + require.NoError(t, err) + + assert.Equal(t, int64(123), r.JobId) + assert.Equal(t, "new job", r.NewSettings.Name) + assert.Equal(t, 0, r.NewSettings.TimeoutSeconds) + assert.Equal(t, 1, r.NewSettings.MaxConcurrentRuns) + assert.Equal(t, 1, len(r.NewSettings.Tasks)) + assert.Equal(t, "new task", r.NewSettings.Tasks[0].TaskKey) + assert.Equal(t, 0, r.NewSettings.Tasks[0].TimeoutSeconds) + assert.Equal(t, 0, r.NewSettings.Tasks[0].MaxRetries) + assert.Equal(t, 0, r.NewSettings.Tasks[0].MinRetryIntervalMillis) + assert.Equal(t, true, r.NewSettings.Tasks[0].RetryOnTimeout) +} + +const incorrectJsonData = ` +{ + "job_id": 123, + "settings": { + "name": "new job", + "email_notifications": { + "on_start": [], + "on_success": [], + "on_failure": [] + }, + "notification_settings": { + "no_alert_for_skipped_runs": true, + "no_alert_for_canceled_runs": true + }, + "timeout_seconds": {}, + "max_concurrent_runs": {}, + "tasks": [ + { + "task_key": "new task", + "email_notifications": {}, + "notification_settings": {}, + "timeout_seconds": 0, + "max_retries": 0, + "min_retry_interval_millis": 0, + "retry_on_timeout": "true" + } + ] + } +} +` + +func TestJsonUnmarshalRequestMismatch(t *testing.T) { + var body JsonFlag + + var r jobs.ResetJob + err := body.Set(incorrectJsonData) + require.NoError(t, err) + + err = body.Unmarshal(&r) + require.ErrorContains(t, err, `json input error: +- unknown field: settings`) }