Added JSON input validation for CLI commands

This commit is contained in:
Andrew Nester 2024-09-13 13:27:53 +02:00
parent f2dee890b8
commit 58ab2f2cfe
No known key found for this signature in database
GPG Key ID: 12BC628A44B7DA57
6 changed files with 313 additions and 6 deletions

View File

@ -221,10 +221,10 @@ func toTypedBool(dst reflect.Value, src dyn.Value) error {
case dyn.KindString: case dyn.KindString:
// See https://github.com/go-yaml/yaml/blob/f6f7691b1fdeb513f56608cd2c32c51f8194bf51/decode.go#L684-L693. // See https://github.com/go-yaml/yaml/blob/f6f7691b1fdeb513f56608cd2c32c51f8194bf51/decode.go#L684-L693.
switch src.MustString() { switch src.MustString() {
case "y", "Y", "yes", "Yes", "YES", "on", "On", "ON": case "y", "Y", "yes", "Yes", "YES", "on", "On", "ON", "true":
dst.SetBool(true) dst.SetBool(true)
return nil return nil
case "n", "N", "no", "No", "NO", "off", "Off", "OFF": case "n", "N", "no", "No", "NO", "off", "Off", "OFF", "false":
dst.SetBool(false) dst.SetBool(false)
return nil return nil
} }
@ -246,6 +246,17 @@ func toTypedInt(dst reflect.Value, src dyn.Value) error {
case dyn.KindInt: case dyn.KindInt:
dst.SetInt(src.MustInt()) dst.SetInt(src.MustInt())
return nil return nil
case dyn.KindFloat:
v := src.MustFloat()
if canConvertToInt(v) {
dst.SetInt(int64(src.MustFloat()))
return nil
}
return TypeError{
value: src,
msg: fmt.Sprintf("expected an int, found a %s", src.Kind()),
}
case dyn.KindString: case dyn.KindString:
if i64, err := strconv.ParseInt(src.MustString(), 10, 64); err == nil { if i64, err := strconv.ParseInt(src.MustString(), 10, 64); err == nil {
dst.SetInt(i64) dst.SetInt(i64)
@ -264,6 +275,10 @@ func toTypedInt(dst reflect.Value, src dyn.Value) error {
} }
} }
func canConvertToInt(v float64) bool {
return v == float64(int(v))
}
func toTypedFloat(dst reflect.Value, src dyn.Value) error { func toTypedFloat(dst reflect.Value, src dyn.Value) error {
switch src.Kind() { switch src.Kind() {
case dyn.KindFloat: case dyn.KindFloat:

View File

@ -0,0 +1,21 @@
package jsonloader
import (
"encoding/json"
"github.com/databricks/cli/libs/dyn"
)
func LoadJSON(data []byte) (dyn.Value, error) {
var root map[string]interface{}
err := json.Unmarshal(data, &root)
if err != nil {
return dyn.InvalidValue, err
}
loc := dyn.Location{
Line: 1,
Column: 1,
}
return newLoader().load(&root, loc)
}

View File

@ -0,0 +1,53 @@
package jsonloader
import (
"testing"
"github.com/databricks/cli/libs/dyn/convert"
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/stretchr/testify/require"
)
const jsonData = `
{
"job_id": 123,
"new_settings": {
"name": "xxx",
"email_notifications": {
"on_start": [],
"on_success": [],
"on_failure": []
},
"webhook_notifications": {
"on_start": [],
"on_failure": []
},
"notification_settings": {
"no_alert_for_skipped_runs": true,
"no_alert_for_canceled_runs": true
},
"timeout_seconds": 0,
"max_concurrent_runs": 1,
"tasks": [
{
"task_key": "xxx",
"email_notifications": {},
"notification_settings": {},
"timeout_seconds": 0,
"max_retries": 0,
"min_retry_interval_millis": 0,
"retry_on_timeout": "true"
}
]
}
}
`
func TestJsonLoader(t *testing.T) {
v, err := LoadJSON([]byte(jsonData))
require.NoError(t, err)
var r jobs.ResetJob
err = convert.ToTyped(&r, v)
require.NoError(t, err)
}

View File

@ -0,0 +1,99 @@
package jsonloader
import (
"fmt"
"reflect"
"github.com/databricks/cli/libs/dyn"
)
type loader struct {
}
func newLoader() *loader {
return &loader{}
}
func errorf(loc dyn.Location, format string, args ...interface{}) error {
return fmt.Errorf("json (%s): %s", loc, fmt.Sprintf(format, args...))
}
func (d *loader) load(node any, loc dyn.Location) (dyn.Value, error) {
var value dyn.Value
var err error
if node == nil {
return dyn.NilValue, nil
}
if reflect.TypeOf(node).Kind() == reflect.Ptr {
return d.load(reflect.ValueOf(node).Elem().Interface(), loc)
}
switch reflect.TypeOf(node).Kind() {
case reflect.Map:
value, err = d.loadMapping(node.(map[string]interface{}), loc)
case reflect.Slice:
value, err = d.loadSequence(node.([]interface{}), loc)
case reflect.String, reflect.Bool,
reflect.Float64, reflect.Float32,
reflect.Int, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint32, reflect.Uint64:
value, err = d.loadScalar(node, loc)
default:
return dyn.InvalidValue, errorf(loc, "unknown node kind: %v", reflect.TypeOf(node).Kind())
}
if err != nil {
return dyn.InvalidValue, err
}
return value, nil
}
func (d *loader) loadScalar(node any, loc dyn.Location) (dyn.Value, error) {
switch reflect.TypeOf(node).Kind() {
case reflect.String:
return dyn.NewValue(node.(string), []dyn.Location{loc}), nil
case reflect.Bool:
return dyn.NewValue(node.(bool), []dyn.Location{loc}), nil
case reflect.Float64, reflect.Float32:
return dyn.NewValue(node.(float64), []dyn.Location{loc}), nil
case reflect.Int, reflect.Int32, reflect.Int64:
return dyn.NewValue(node.(int64), []dyn.Location{loc}), nil
case reflect.Uint, reflect.Uint32, reflect.Uint64:
return dyn.NewValue(node.(uint64), []dyn.Location{loc}), nil
default:
return dyn.InvalidValue, errorf(loc, "unknown scalar type: %v", reflect.TypeOf(node).Kind())
}
}
func (d *loader) loadSequence(node []interface{}, loc dyn.Location) (dyn.Value, error) {
dst := make([]dyn.Value, len(node))
for i, value := range node {
v, err := d.load(value, loc)
if err != nil {
return dyn.InvalidValue, err
}
dst[i] = v
}
return dyn.NewValue(dst, []dyn.Location{loc}), nil
}
func (d *loader) loadMapping(node map[string]interface{}, loc dyn.Location) (dyn.Value, error) {
dst := make(map[string]dyn.Value)
index := 0
for key, value := range node {
index += 1
v, err := d.load(value, dyn.Location{
Line: loc.Line + index,
Column: loc.Column,
})
if err != nil {
return dyn.InvalidValue, err
}
dst[key] = v
}
return dyn.NewValue(dst, []dyn.Location{loc}), nil
}

View File

@ -1,9 +1,11 @@
package flags package flags
import ( import (
"encoding/json"
"fmt" "fmt"
"os" "os"
"github.com/databricks/cli/libs/dyn/convert"
"github.com/databricks/cli/libs/dyn/jsonloader"
) )
type JsonFlag struct { type JsonFlag struct {
@ -33,7 +35,27 @@ func (j *JsonFlag) Unmarshal(v any) error {
if j.raw == nil { if j.raw == nil {
return nil return nil
} }
return json.Unmarshal(j.raw, v)
dv, err := jsonloader.LoadJSON(j.raw)
if err != nil {
return err
}
err = convert.ToTyped(v, dv)
if err != nil {
return err
}
_, diags := convert.Normalize(v, dv)
if len(diags) > 0 {
summary := ""
for _, diag := range diags {
summary += fmt.Sprintf("- %s\n", diag.Summary)
}
return fmt.Errorf("json input error:\n%v", summary)
}
return nil
} }
func (j *JsonFlag) Type() string { func (j *JsonFlag) Type() string {

View File

@ -6,6 +6,7 @@ import (
"path" "path"
"testing" "testing"
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -52,7 +53,7 @@ func TestJsonFlagFile(t *testing.T) {
var request any var request any
var fpath string var fpath string
var payload = []byte(`"hello world"`) var payload = []byte(`{"hello": "world"}`)
{ {
f, err := os.Create(path.Join(t.TempDir(), "file")) f, err := os.Create(path.Join(t.TempDir(), "file"))
@ -68,5 +69,101 @@ func TestJsonFlagFile(t *testing.T) {
err = body.Unmarshal(&request) err = body.Unmarshal(&request)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "hello world", request) assert.Equal(t, map[string]interface{}{"hello": "world"}, request)
}
const jsonData = `
{
"job_id": 123,
"new_settings": {
"name": "new job",
"email_notifications": {
"on_start": [],
"on_success": [],
"on_failure": []
},
"notification_settings": {
"no_alert_for_skipped_runs": true,
"no_alert_for_canceled_runs": true
},
"timeout_seconds": 0,
"max_concurrent_runs": 1,
"tasks": [
{
"task_key": "new task",
"email_notifications": {},
"notification_settings": {},
"timeout_seconds": 0,
"max_retries": 0,
"min_retry_interval_millis": 0,
"retry_on_timeout": "true"
}
]
}
}
`
func TestJsonUnmarshalForRequest(t *testing.T) {
var body JsonFlag
var r jobs.ResetJob
err := body.Set(jsonData)
require.NoError(t, err)
err = body.Unmarshal(&r)
require.NoError(t, err)
assert.Equal(t, int64(123), r.JobId)
assert.Equal(t, "new job", r.NewSettings.Name)
assert.Equal(t, 0, r.NewSettings.TimeoutSeconds)
assert.Equal(t, 1, r.NewSettings.MaxConcurrentRuns)
assert.Equal(t, 1, len(r.NewSettings.Tasks))
assert.Equal(t, "new task", r.NewSettings.Tasks[0].TaskKey)
assert.Equal(t, 0, r.NewSettings.Tasks[0].TimeoutSeconds)
assert.Equal(t, 0, r.NewSettings.Tasks[0].MaxRetries)
assert.Equal(t, 0, r.NewSettings.Tasks[0].MinRetryIntervalMillis)
assert.Equal(t, true, r.NewSettings.Tasks[0].RetryOnTimeout)
}
const incorrectJsonData = `
{
"job_id": 123,
"settings": {
"name": "new job",
"email_notifications": {
"on_start": [],
"on_success": [],
"on_failure": []
},
"notification_settings": {
"no_alert_for_skipped_runs": true,
"no_alert_for_canceled_runs": true
},
"timeout_seconds": {},
"max_concurrent_runs": {},
"tasks": [
{
"task_key": "new task",
"email_notifications": {},
"notification_settings": {},
"timeout_seconds": 0,
"max_retries": 0,
"min_retry_interval_millis": 0,
"retry_on_timeout": "true"
}
]
}
}
`
func TestJsonUnmarshalRequestMismatch(t *testing.T) {
var body JsonFlag
var r jobs.ResetJob
err := body.Set(incorrectJsonData)
require.NoError(t, err)
err = body.Unmarshal(&r)
require.ErrorContains(t, err, `json input error:
- unknown field: settings`)
} }