mirror of https://github.com/databricks/cli.git
Added JSON input validation for CLI commands
This commit is contained in:
parent
f2dee890b8
commit
58ab2f2cfe
|
@ -221,10 +221,10 @@ func toTypedBool(dst reflect.Value, src dyn.Value) error {
|
|||
case dyn.KindString:
|
||||
// See https://github.com/go-yaml/yaml/blob/f6f7691b1fdeb513f56608cd2c32c51f8194bf51/decode.go#L684-L693.
|
||||
switch src.MustString() {
|
||||
case "y", "Y", "yes", "Yes", "YES", "on", "On", "ON":
|
||||
case "y", "Y", "yes", "Yes", "YES", "on", "On", "ON", "true":
|
||||
dst.SetBool(true)
|
||||
return nil
|
||||
case "n", "N", "no", "No", "NO", "off", "Off", "OFF":
|
||||
case "n", "N", "no", "No", "NO", "off", "Off", "OFF", "false":
|
||||
dst.SetBool(false)
|
||||
return nil
|
||||
}
|
||||
|
@ -246,6 +246,17 @@ func toTypedInt(dst reflect.Value, src dyn.Value) error {
|
|||
case dyn.KindInt:
|
||||
dst.SetInt(src.MustInt())
|
||||
return nil
|
||||
case dyn.KindFloat:
|
||||
v := src.MustFloat()
|
||||
if canConvertToInt(v) {
|
||||
dst.SetInt(int64(src.MustFloat()))
|
||||
return nil
|
||||
}
|
||||
|
||||
return TypeError{
|
||||
value: src,
|
||||
msg: fmt.Sprintf("expected an int, found a %s", src.Kind()),
|
||||
}
|
||||
case dyn.KindString:
|
||||
if i64, err := strconv.ParseInt(src.MustString(), 10, 64); err == nil {
|
||||
dst.SetInt(i64)
|
||||
|
@ -264,6 +275,10 @@ func toTypedInt(dst reflect.Value, src dyn.Value) error {
|
|||
}
|
||||
}
|
||||
|
||||
func canConvertToInt(v float64) bool {
|
||||
return v == float64(int(v))
|
||||
}
|
||||
|
||||
func toTypedFloat(dst reflect.Value, src dyn.Value) error {
|
||||
switch src.Kind() {
|
||||
case dyn.KindFloat:
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
package jsonloader
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/databricks/cli/libs/dyn"
|
||||
)
|
||||
|
||||
func LoadJSON(data []byte) (dyn.Value, error) {
|
||||
var root map[string]interface{}
|
||||
err := json.Unmarshal(data, &root)
|
||||
if err != nil {
|
||||
return dyn.InvalidValue, err
|
||||
}
|
||||
|
||||
loc := dyn.Location{
|
||||
Line: 1,
|
||||
Column: 1,
|
||||
}
|
||||
return newLoader().load(&root, loc)
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
package jsonloader
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/databricks/cli/libs/dyn/convert"
|
||||
"github.com/databricks/databricks-sdk-go/service/jobs"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const jsonData = `
|
||||
{
|
||||
"job_id": 123,
|
||||
"new_settings": {
|
||||
"name": "xxx",
|
||||
"email_notifications": {
|
||||
"on_start": [],
|
||||
"on_success": [],
|
||||
"on_failure": []
|
||||
},
|
||||
"webhook_notifications": {
|
||||
"on_start": [],
|
||||
"on_failure": []
|
||||
},
|
||||
"notification_settings": {
|
||||
"no_alert_for_skipped_runs": true,
|
||||
"no_alert_for_canceled_runs": true
|
||||
},
|
||||
"timeout_seconds": 0,
|
||||
"max_concurrent_runs": 1,
|
||||
"tasks": [
|
||||
{
|
||||
"task_key": "xxx",
|
||||
"email_notifications": {},
|
||||
"notification_settings": {},
|
||||
"timeout_seconds": 0,
|
||||
"max_retries": 0,
|
||||
"min_retry_interval_millis": 0,
|
||||
"retry_on_timeout": "true"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
`
|
||||
|
||||
func TestJsonLoader(t *testing.T) {
|
||||
v, err := LoadJSON([]byte(jsonData))
|
||||
require.NoError(t, err)
|
||||
|
||||
var r jobs.ResetJob
|
||||
err = convert.ToTyped(&r, v)
|
||||
require.NoError(t, err)
|
||||
}
|
|
@ -0,0 +1,99 @@
|
|||
package jsonloader
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/databricks/cli/libs/dyn"
|
||||
)
|
||||
|
||||
type loader struct {
|
||||
}
|
||||
|
||||
func newLoader() *loader {
|
||||
return &loader{}
|
||||
}
|
||||
|
||||
func errorf(loc dyn.Location, format string, args ...interface{}) error {
|
||||
return fmt.Errorf("json (%s): %s", loc, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (d *loader) load(node any, loc dyn.Location) (dyn.Value, error) {
|
||||
var value dyn.Value
|
||||
var err error
|
||||
|
||||
if node == nil {
|
||||
return dyn.NilValue, nil
|
||||
}
|
||||
|
||||
if reflect.TypeOf(node).Kind() == reflect.Ptr {
|
||||
return d.load(reflect.ValueOf(node).Elem().Interface(), loc)
|
||||
}
|
||||
|
||||
switch reflect.TypeOf(node).Kind() {
|
||||
case reflect.Map:
|
||||
value, err = d.loadMapping(node.(map[string]interface{}), loc)
|
||||
case reflect.Slice:
|
||||
value, err = d.loadSequence(node.([]interface{}), loc)
|
||||
case reflect.String, reflect.Bool,
|
||||
reflect.Float64, reflect.Float32,
|
||||
reflect.Int, reflect.Int32, reflect.Int64,
|
||||
reflect.Uint, reflect.Uint32, reflect.Uint64:
|
||||
value, err = d.loadScalar(node, loc)
|
||||
|
||||
default:
|
||||
return dyn.InvalidValue, errorf(loc, "unknown node kind: %v", reflect.TypeOf(node).Kind())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return dyn.InvalidValue, err
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (d *loader) loadScalar(node any, loc dyn.Location) (dyn.Value, error) {
|
||||
switch reflect.TypeOf(node).Kind() {
|
||||
case reflect.String:
|
||||
return dyn.NewValue(node.(string), []dyn.Location{loc}), nil
|
||||
case reflect.Bool:
|
||||
return dyn.NewValue(node.(bool), []dyn.Location{loc}), nil
|
||||
case reflect.Float64, reflect.Float32:
|
||||
return dyn.NewValue(node.(float64), []dyn.Location{loc}), nil
|
||||
case reflect.Int, reflect.Int32, reflect.Int64:
|
||||
return dyn.NewValue(node.(int64), []dyn.Location{loc}), nil
|
||||
case reflect.Uint, reflect.Uint32, reflect.Uint64:
|
||||
return dyn.NewValue(node.(uint64), []dyn.Location{loc}), nil
|
||||
default:
|
||||
return dyn.InvalidValue, errorf(loc, "unknown scalar type: %v", reflect.TypeOf(node).Kind())
|
||||
}
|
||||
}
|
||||
|
||||
func (d *loader) loadSequence(node []interface{}, loc dyn.Location) (dyn.Value, error) {
|
||||
dst := make([]dyn.Value, len(node))
|
||||
for i, value := range node {
|
||||
v, err := d.load(value, loc)
|
||||
if err != nil {
|
||||
return dyn.InvalidValue, err
|
||||
}
|
||||
dst[i] = v
|
||||
}
|
||||
return dyn.NewValue(dst, []dyn.Location{loc}), nil
|
||||
}
|
||||
|
||||
func (d *loader) loadMapping(node map[string]interface{}, loc dyn.Location) (dyn.Value, error) {
|
||||
dst := make(map[string]dyn.Value)
|
||||
index := 0
|
||||
for key, value := range node {
|
||||
index += 1
|
||||
v, err := d.load(value, dyn.Location{
|
||||
Line: loc.Line + index,
|
||||
Column: loc.Column,
|
||||
})
|
||||
if err != nil {
|
||||
return dyn.InvalidValue, err
|
||||
}
|
||||
dst[key] = v
|
||||
}
|
||||
return dyn.NewValue(dst, []dyn.Location{loc}), nil
|
||||
}
|
|
@ -1,9 +1,11 @@
|
|||
package flags
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/databricks/cli/libs/dyn/convert"
|
||||
"github.com/databricks/cli/libs/dyn/jsonloader"
|
||||
)
|
||||
|
||||
type JsonFlag struct {
|
||||
|
@ -33,7 +35,27 @@ func (j *JsonFlag) Unmarshal(v any) error {
|
|||
if j.raw == nil {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(j.raw, v)
|
||||
|
||||
dv, err := jsonloader.LoadJSON(j.raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = convert.ToTyped(v, dv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, diags := convert.Normalize(v, dv)
|
||||
if len(diags) > 0 {
|
||||
summary := ""
|
||||
for _, diag := range diags {
|
||||
summary += fmt.Sprintf("- %s\n", diag.Summary)
|
||||
}
|
||||
return fmt.Errorf("json input error:\n%v", summary)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (j *JsonFlag) Type() string {
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/databricks/databricks-sdk-go/service/jobs"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
@ -52,7 +53,7 @@ func TestJsonFlagFile(t *testing.T) {
|
|||
var request any
|
||||
|
||||
var fpath string
|
||||
var payload = []byte(`"hello world"`)
|
||||
var payload = []byte(`{"hello": "world"}`)
|
||||
|
||||
{
|
||||
f, err := os.Create(path.Join(t.TempDir(), "file"))
|
||||
|
@ -68,5 +69,101 @@ func TestJsonFlagFile(t *testing.T) {
|
|||
err = body.Unmarshal(&request)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "hello world", request)
|
||||
assert.Equal(t, map[string]interface{}{"hello": "world"}, request)
|
||||
}
|
||||
|
||||
const jsonData = `
|
||||
{
|
||||
"job_id": 123,
|
||||
"new_settings": {
|
||||
"name": "new job",
|
||||
"email_notifications": {
|
||||
"on_start": [],
|
||||
"on_success": [],
|
||||
"on_failure": []
|
||||
},
|
||||
"notification_settings": {
|
||||
"no_alert_for_skipped_runs": true,
|
||||
"no_alert_for_canceled_runs": true
|
||||
},
|
||||
"timeout_seconds": 0,
|
||||
"max_concurrent_runs": 1,
|
||||
"tasks": [
|
||||
{
|
||||
"task_key": "new task",
|
||||
"email_notifications": {},
|
||||
"notification_settings": {},
|
||||
"timeout_seconds": 0,
|
||||
"max_retries": 0,
|
||||
"min_retry_interval_millis": 0,
|
||||
"retry_on_timeout": "true"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
`
|
||||
|
||||
func TestJsonUnmarshalForRequest(t *testing.T) {
|
||||
var body JsonFlag
|
||||
|
||||
var r jobs.ResetJob
|
||||
err := body.Set(jsonData)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = body.Unmarshal(&r)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, int64(123), r.JobId)
|
||||
assert.Equal(t, "new job", r.NewSettings.Name)
|
||||
assert.Equal(t, 0, r.NewSettings.TimeoutSeconds)
|
||||
assert.Equal(t, 1, r.NewSettings.MaxConcurrentRuns)
|
||||
assert.Equal(t, 1, len(r.NewSettings.Tasks))
|
||||
assert.Equal(t, "new task", r.NewSettings.Tasks[0].TaskKey)
|
||||
assert.Equal(t, 0, r.NewSettings.Tasks[0].TimeoutSeconds)
|
||||
assert.Equal(t, 0, r.NewSettings.Tasks[0].MaxRetries)
|
||||
assert.Equal(t, 0, r.NewSettings.Tasks[0].MinRetryIntervalMillis)
|
||||
assert.Equal(t, true, r.NewSettings.Tasks[0].RetryOnTimeout)
|
||||
}
|
||||
|
||||
const incorrectJsonData = `
|
||||
{
|
||||
"job_id": 123,
|
||||
"settings": {
|
||||
"name": "new job",
|
||||
"email_notifications": {
|
||||
"on_start": [],
|
||||
"on_success": [],
|
||||
"on_failure": []
|
||||
},
|
||||
"notification_settings": {
|
||||
"no_alert_for_skipped_runs": true,
|
||||
"no_alert_for_canceled_runs": true
|
||||
},
|
||||
"timeout_seconds": {},
|
||||
"max_concurrent_runs": {},
|
||||
"tasks": [
|
||||
{
|
||||
"task_key": "new task",
|
||||
"email_notifications": {},
|
||||
"notification_settings": {},
|
||||
"timeout_seconds": 0,
|
||||
"max_retries": 0,
|
||||
"min_retry_interval_millis": 0,
|
||||
"retry_on_timeout": "true"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
`
|
||||
|
||||
func TestJsonUnmarshalRequestMismatch(t *testing.T) {
|
||||
var body JsonFlag
|
||||
|
||||
var r jobs.ResetJob
|
||||
err := body.Set(incorrectJsonData)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = body.Unmarshal(&r)
|
||||
require.ErrorContains(t, err, `json input error:
|
||||
- unknown field: settings`)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue