diff --git a/bundle/python/transform.go b/bundle/python/transform.go index f6207a59..a3fea2e8 100644 --- a/bundle/python/transform.go +++ b/bundle/python/transform.go @@ -31,8 +31,21 @@ except ImportError: # for Python<3.8 from contextlib import redirect_stdout import io import sys +import json + +params = [] +try: + python_params = dbutils.widgets.get("__python_params") + if python_params: + params = json.loads(python_params) +except Exception as e: + print(e) + sys.argv = [{{.Params}}] +if params: + sys.argv = [sys.argv[0]] + params + entry = [ep for ep in metadata.distribution("{{.Task.PackageName}}").entry_points if ep.name == "{{.Task.EntryPoint}}"] f = io.StringIO() diff --git a/bundle/run/job.go b/bundle/run/job.go index b94e8fef..a6343b97 100644 --- a/bundle/run/job.go +++ b/bundle/run/job.go @@ -2,6 +2,7 @@ package run import ( "context" + "encoding/json" "fmt" "strconv" "time" @@ -221,6 +222,11 @@ func (r *jobRunner) Run(ctx context.Context, opts *Options) (output.RunOutput, e runId := new(int64) + err = r.convertPythonParams(opts) + if err != nil { + return nil, err + } + // construct request payload from cmd line flags args req, err := opts.Job.toPayload(jobID) if err != nil { @@ -299,3 +305,42 @@ func (r *jobRunner) Run(ctx context.Context, opts *Options) (output.RunOutput, e return nil, err } + +func (r *jobRunner) convertPythonParams(opts *Options) error { + if r.bundle.Config.Experimental != nil && !r.bundle.Config.Experimental.PythonWheelWrapper { + return nil + } + + needConvert := false + for _, task := range r.job.Tasks { + if task.PythonWheelTask != nil { + needConvert = true + break + } + } + + if !needConvert { + return nil + } + + if len(opts.Job.pythonParams) == 0 { + return nil + } + + if opts.Job.notebookParams == nil { + opts.Job.notebookParams = make(map[string]string) + } + + if len(opts.Job.pythonParams) > 0 { + if _, ok := opts.Job.notebookParams["__python_params"]; ok { + return fmt.Errorf("can't use __python_params as notebook param, the name is reserved for internal use") + } + p, err := json.Marshal(opts.Job.pythonParams) + if err != nil { + return err + } + opts.Job.notebookParams["__python_params"] = string(p) + } + + return nil +} diff --git a/bundle/run/job_test.go b/bundle/run/job_test.go new file mode 100644 index 00000000..e4cb4e7e --- /dev/null +++ b/bundle/run/job_test.go @@ -0,0 +1,49 @@ +package run + +import ( + "testing" + + "github.com/databricks/cli/bundle" + "github.com/databricks/cli/bundle/config" + "github.com/databricks/cli/bundle/config/resources" + "github.com/databricks/databricks-sdk-go/service/jobs" + "github.com/stretchr/testify/require" +) + +func TestConvertPythonParams(t *testing.T) { + job := &resources.Job{ + JobSettings: &jobs.JobSettings{ + Tasks: []jobs.Task{ + {PythonWheelTask: &jobs.PythonWheelTask{ + PackageName: "my_test_code", + EntryPoint: "run", + }}, + }, + }, + } + b := &bundle.Bundle{ + Config: config.Root{ + Resources: config.Resources{ + Jobs: map[string]*resources.Job{ + "test_job": job, + }, + }, + }, + } + runner := jobRunner{key: "test", bundle: b, job: job} + + opts := &Options{ + Job: JobOptions{}, + } + runner.convertPythonParams(opts) + require.NotContains(t, opts.Job.notebookParams, "__python_params") + + opts = &Options{ + Job: JobOptions{ + pythonParams: []string{"param1", "param2", "param3"}, + }, + } + runner.convertPythonParams(opts) + require.Contains(t, opts.Job.notebookParams, "__python_params") + require.Equal(t, opts.Job.notebookParams["__python_params"], `["param1","param2","param3"]`) +} diff --git a/internal/bundle/helpers.go b/internal/bundle/helpers.go index 3fd4eabc..681edc2d 100644 --- a/internal/bundle/helpers.go +++ b/internal/bundle/helpers.go @@ -62,6 +62,18 @@ func runResource(t *testing.T, path string, key string) (string, error) { return stdout.String(), err } +func runResourceWithParams(t *testing.T, path string, key string, params ...string) (string, error) { + ctx := context.Background() + ctx = cmdio.NewContext(ctx, cmdio.Default()) + + args := make([]string, 0) + args = append(args, "bundle", "run", key) + args = append(args, params...) + c := internal.NewCobraTestRunnerWithContext(t, ctx, args...) + stdout, _, err := c.Run() + return stdout.String(), err +} + func destroyBundle(t *testing.T, path string) error { t.Setenv("BUNDLE_ROOT", path) c := internal.NewCobraTestRunner(t, "bundle", "destroy", "--auto-approve") diff --git a/internal/bundle/python_wheel_test.go b/internal/bundle/python_wheel_test.go index bfc2d8b2..c94ed93a 100644 --- a/internal/bundle/python_wheel_test.go +++ b/internal/bundle/python_wheel_test.go @@ -41,6 +41,12 @@ func runPythonWheelTest(t *testing.T, sparkVersion string, pythonWheelWrapper bo require.Contains(t, out, "Hello from my func") require.Contains(t, out, "Got arguments:") require.Contains(t, out, "['my_test_code', 'one', 'two']") + + out, err = runResourceWithParams(t, bundleRoot, "some_other_job", "--python-params=param1,param2") + require.NoError(t, err) + require.Contains(t, out, "Hello from my func") + require.Contains(t, out, "Got arguments:") + require.Contains(t, out, "['my_test_code', 'param1', 'param2']") } func TestAccPythonWheelTaskDeployAndRunWithoutWrapper(t *testing.T) {