From 70fe0e36efa1666530381440027f66057b409c29 Mon Sep 17 00:00:00 2001
From: Andrew Nester
Date: Wed, 17 Jan 2024 14:26:33 +0000
Subject: [PATCH] Added `databricks bundle generate job` command (#1043)

## Changes
Now it's possible to generate bundle configuration for an existing job.
For now it only supports jobs with notebook tasks. It will download the
notebooks referenced in the job tasks and generate bundle YAML config for
this job, which can be included in a larger bundle.

Example of generated config:
```
resources:
  jobs:
    job_128737545467921:
      name: Notebook job
      format: MULTI_TASK
      tasks:
        - task_key: as_notebook
          existing_cluster_id: 0704-xxxxxx-yyyyyyy
          notebook_task:
            base_parameters:
              bundle_root: /Users/andrew.nester@databricks.com/.bundle/job_with_module_imports/development/files
            notebook_path: ./entry_notebook.py
            source: WORKSPACE
          run_if: ALL_SUCCESS
      max_concurrent_runs: 1
```

## Tests
Ran the command manually (on our last 100 jobs) and added an end-to-end test:
```
--- PASS: TestAccGenerateFromExistingJobAndDeploy (50.91s)
PASS
coverage: 61.5% of statements in ./...
ok      github.com/databricks/cli/internal/bundle      51.209s coverage: 61.5% of statements in ./...
```
---
 bundle/config/generate/job.go                 |  34 +++
 .../config/mutator/populate_current_user.go   |  11 +-
 cmd/bundle/bundle.go                          |   1 +
 cmd/bundle/generate.go                        |  18 ++
 cmd/bundle/generate/job.go                    |  91 ++++++++
 cmd/bundle/generate/utils.go                  | 107 ++++++++++
 cmd/workspace/workspace/export_dir.go         |  16 +-
 .../databricks_template_schema.json           |   8 +
 .../template/databricks.yml.tmpl              |   8 +
 internal/bundle/generate_job_test.go          | 124 +++++++++++
 libs/dyn/value.go                             |   4 +
 libs/dyn/yamlsaver/order.go                   |  33 +++
 libs/dyn/yamlsaver/order_test.go              |  24 +++
 libs/dyn/yamlsaver/saver.go                   | 139 +++++++++++++
 libs/dyn/yamlsaver/saver_test.go              | 195 ++++++++++++++++++
 libs/dyn/yamlsaver/utils.go                   |  49 +++++
 libs/dyn/yamlsaver/utils_test.go              |  48 +++++
 libs/notebook/ext.go                          |  23 +++
 libs/textutil/textutil.go                     |  20 ++
 libs/textutil/textutil_test.go                |  54 +++++
 20 files changed, 984 insertions(+), 23 deletions(-)
 create mode 100644 bundle/config/generate/job.go
 create mode 100644 cmd/bundle/generate.go
 create mode 100644 cmd/bundle/generate/job.go
 create mode 100644 cmd/bundle/generate/utils.go
 create mode 100644 internal/bundle/bundles/with_includes/databricks_template_schema.json
 create mode 100644 internal/bundle/bundles/with_includes/template/databricks.yml.tmpl
 create mode 100644 internal/bundle/generate_job_test.go
 create mode 100644 libs/dyn/yamlsaver/order.go
 create mode 100644 libs/dyn/yamlsaver/order_test.go
 create mode 100644 libs/dyn/yamlsaver/saver.go
 create mode 100644 libs/dyn/yamlsaver/saver_test.go
 create mode 100644 libs/dyn/yamlsaver/utils.go
 create mode 100644 libs/dyn/yamlsaver/utils_test.go
 create mode 100644 libs/notebook/ext.go
 create mode 100644 libs/textutil/textutil.go
 create mode 100644 libs/textutil/textutil_test.go

diff --git a/bundle/config/generate/job.go b/bundle/config/generate/job.go
new file mode 100644
index 00000000..469f8422
--- /dev/null
+++ b/bundle/config/generate/job.go
@@ -0,0 +1,34 @@
+package generate
+
+import (
+	"github.com/databricks/cli/libs/dyn"
+	"github.com/databricks/cli/libs/dyn/yamlsaver"
+	"github.com/databricks/databricks-sdk-go/service/jobs"
+)
+
+var jobOrder = yamlsaver.NewOrder([]string{"name", "job_clusters", "compute", "tasks"})
+var taskOrder = yamlsaver.NewOrder([]string{"task_key", "depends_on", "existing_cluster_id", "new_cluster", "job_cluster_key"})
+
+func ConvertJobToValue(job *jobs.Job) (dyn.Value, error) {
+	
value := make(map[string]dyn.Value) + + if job.Settings.Tasks != nil { + tasks := make([]dyn.Value, 0) + for _, task := range job.Settings.Tasks { + v, err := convertTaskToValue(task, taskOrder) + if err != nil { + return dyn.NilValue, err + } + tasks = append(tasks, v) + } + // We're using location lines to define the order of keys in exported YAML. + value["tasks"] = dyn.NewValue(tasks, dyn.Location{Line: jobOrder.Get("tasks")}) + } + + return yamlsaver.ConvertToMapValue(job.Settings, jobOrder, []string{"format", "new_cluster", "existing_cluster_id"}, value) +} + +func convertTaskToValue(task jobs.Task, order *yamlsaver.Order) (dyn.Value, error) { + dst := make(map[string]dyn.Value) + return yamlsaver.ConvertToMapValue(task, order, []string{"format"}, dst) +} diff --git a/bundle/config/mutator/populate_current_user.go b/bundle/config/mutator/populate_current_user.go index 5b5d3096..60587578 100644 --- a/bundle/config/mutator/populate_current_user.go +++ b/bundle/config/mutator/populate_current_user.go @@ -3,11 +3,11 @@ package mutator import ( "context" "strings" - "unicode" "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config" "github.com/databricks/cli/libs/tags" + "github.com/databricks/cli/libs/textutil" ) type populateCurrentUser struct{} @@ -43,17 +43,10 @@ func (m *populateCurrentUser) Apply(ctx context.Context, b *bundle.Bundle) error return nil } -func replaceNonAlphanumeric(r rune) rune { - if unicode.IsLetter(r) || unicode.IsDigit(r) { - return r - } - return '_' -} - // Get a short-form username, based on the user's primary email address. // We leave the full range of unicode letters in tact, but remove all "special" characters, // including dots, which are not supported in e.g. experiment names. func getShortUserName(emailAddress string) string { local, _, _ := strings.Cut(emailAddress, "@") - return strings.Map(replaceNonAlphanumeric, local) + return textutil.NormalizeString(local) } diff --git a/cmd/bundle/bundle.go b/cmd/bundle/bundle.go index 3206b94e..3aa6945b 100644 --- a/cmd/bundle/bundle.go +++ b/cmd/bundle/bundle.go @@ -22,5 +22,6 @@ func New() *cobra.Command { cmd.AddCommand(newTestCommand()) cmd.AddCommand(newValidateCommand()) cmd.AddCommand(newInitCommand()) + cmd.AddCommand(newGenerateCommand()) return cmd } diff --git a/cmd/bundle/generate.go b/cmd/bundle/generate.go new file mode 100644 index 00000000..a593f52f --- /dev/null +++ b/cmd/bundle/generate.go @@ -0,0 +1,18 @@ +package bundle + +import ( + "github.com/databricks/cli/cmd/bundle/generate" + "github.com/spf13/cobra" +) + +func newGenerateCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "generate", + Short: "Generate bundle configuration", + Long: "Generate bundle configuration", + PreRunE: ConfigureBundleWithVariables, + } + + cmd.AddCommand(generate.NewGenerateJobCommand()) + return cmd +} diff --git a/cmd/bundle/generate/job.go b/cmd/bundle/generate/job.go new file mode 100644 index 00000000..8e186cc3 --- /dev/null +++ b/cmd/bundle/generate/job.go @@ -0,0 +1,91 @@ +package generate + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/databricks/cli/bundle" + "github.com/databricks/cli/bundle/config/generate" + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/dyn" + "github.com/databricks/cli/libs/dyn/yamlsaver" + "github.com/databricks/cli/libs/textutil" + "github.com/databricks/databricks-sdk-go/service/jobs" + "github.com/spf13/cobra" +) + +func NewGenerateJobCommand() *cobra.Command { + var 
configDir string + var sourceDir string + var jobId int64 + var force bool + + cmd := &cobra.Command{ + Use: "job", + Short: "Generate bundle configuration for a job", + PreRunE: root.MustConfigureBundle, + } + + cmd.Flags().Int64Var(&jobId, "existing-job-id", 0, `Job ID of the job to generate config for`) + cmd.MarkFlagRequired("existing-job-id") + + wd, err := os.Getwd() + if err != nil { + wd = "." + } + + cmd.Flags().StringVarP(&configDir, "config-dir", "d", filepath.Join(wd, "resources"), `Dir path where the output config will be stored`) + cmd.Flags().StringVarP(&sourceDir, "source-dir", "s", filepath.Join(wd, "src"), `Dir path where the downloaded files will be stored`) + cmd.Flags().BoolVarP(&force, "force", "f", false, `Force overwrite existing files in the output directory`) + + cmd.RunE = func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + b := bundle.Get(ctx) + w := b.WorkspaceClient() + + job, err := w.Jobs.Get(ctx, jobs.GetJobRequest{JobId: jobId}) + if err != nil { + return err + } + + downloader := newNotebookDownloader(w, sourceDir, configDir) + for _, task := range job.Settings.Tasks { + err := downloader.MarkForDownload(ctx, &task) + if err != nil { + return err + } + } + + v, err := generate.ConvertJobToValue(job) + if err != nil { + return err + } + + jobKey := fmt.Sprintf("job_%s", textutil.NormalizeString(job.Settings.Name)) + result := map[string]dyn.Value{ + "resources": dyn.V(map[string]dyn.Value{ + "jobs": dyn.V(map[string]dyn.Value{ + jobKey: v, + }), + }), + } + + err = downloader.FlushToDisk(ctx, force) + if err != nil { + return err + } + + filename := filepath.Join(configDir, fmt.Sprintf("%s.yml", jobKey)) + err = yamlsaver.SaveAsYAML(result, filename, force) + if err != nil { + return err + } + + cmdio.LogString(ctx, fmt.Sprintf("Job configuration successfully saved to %s", filename)) + return nil + } + + return cmd +} diff --git a/cmd/bundle/generate/utils.go b/cmd/bundle/generate/utils.go new file mode 100644 index 00000000..8d450088 --- /dev/null +++ b/cmd/bundle/generate/utils.go @@ -0,0 +1,107 @@ +package generate + +import ( + "context" + "fmt" + "io" + "os" + "path" + "path/filepath" + + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/notebook" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/service/jobs" + "golang.org/x/sync/errgroup" +) + +type notebookDownloader struct { + notebooks map[string]string + w *databricks.WorkspaceClient + sourceDir string + configDir string +} + +func (n *notebookDownloader) MarkForDownload(ctx context.Context, task *jobs.Task) error { + if task.NotebookTask == nil { + return nil + } + + info, err := n.w.Workspace.GetStatusByPath(ctx, task.NotebookTask.NotebookPath) + if err != nil { + return err + } + + ext := notebook.GetExtensionByLanguage(info) + + filename := path.Base(task.NotebookTask.NotebookPath) + ext + targetPath := filepath.Join(n.sourceDir, filename) + + n.notebooks[targetPath] = task.NotebookTask.NotebookPath + + // Update the notebook path to be relative to the config dir + rel, err := filepath.Rel(n.configDir, targetPath) + if err != nil { + return err + } + + task.NotebookTask.NotebookPath = rel + return nil +} + +func (n *notebookDownloader) FlushToDisk(ctx context.Context, force bool) error { + err := os.MkdirAll(n.sourceDir, 0755) + if err != nil { + return err + } + + // First check that all files can be written + for targetPath := range n.notebooks { + info, err := os.Stat(targetPath) + if err == nil { + if 
info.IsDir() {
+				return fmt.Errorf("%s is a directory", targetPath)
+			}
+			if !force {
+				return fmt.Errorf("%s already exists. Use --force to overwrite", targetPath)
+			}
+		}
+	}
+
+	errs, errCtx := errgroup.WithContext(ctx)
+	for k, v := range n.notebooks {
+		targetPath := k
+		notebookPath := v
+		errs.Go(func() error {
+			reader, err := n.w.Workspace.Download(errCtx, notebookPath)
+			if err != nil {
+				return err
+			}
+
+			file, err := os.Create(targetPath)
+			if err != nil {
+				return err
+			}
+			defer file.Close()
+
+			_, err = io.Copy(file, reader)
+			if err != nil {
+				return err
+			}
+
+			cmdio.LogString(errCtx, fmt.Sprintf("Notebook successfully saved to %s", targetPath))
+			return reader.Close()
+		})
+	}
+
+	return errs.Wait()
+}
+
+func newNotebookDownloader(w *databricks.WorkspaceClient, sourceDir string, configDir string) *notebookDownloader {
+	return &notebookDownloader{
+		notebooks: make(map[string]string),
+		w:         w,
+		sourceDir: sourceDir,
+		configDir: configDir,
+	}
+}
diff --git a/cmd/workspace/workspace/export_dir.go b/cmd/workspace/workspace/export_dir.go
index 4f50a96e..d2a86d00 100644
--- a/cmd/workspace/workspace/export_dir.go
+++ b/cmd/workspace/workspace/export_dir.go
@@ -11,6 +11,7 @@ import (
 	"github.com/databricks/cli/cmd/root"
 	"github.com/databricks/cli/libs/cmdio"
 	"github.com/databricks/cli/libs/filer"
+	"github.com/databricks/cli/libs/notebook"
 	"github.com/databricks/databricks-sdk-go/service/workspace"
 	"github.com/spf13/cobra"
 )
@@ -47,20 +48,7 @@ func (opts exportDirOptions) callback(ctx context.Context, workspaceFiler filer.
 		return err
 	}
 	objectInfo := info.Sys().(workspace.ObjectInfo)
-	if objectInfo.ObjectType == workspace.ObjectTypeNotebook {
-		switch objectInfo.Language {
-		case workspace.LanguagePython:
-			targetPath += ".py"
-		case workspace.LanguageR:
-			targetPath += ".r"
-		case workspace.LanguageScala:
-			targetPath += ".scala"
-		case workspace.LanguageSql:
-			targetPath += ".sql"
-		default:
-			// Do not add any extension to the file name
-		}
-	}
+	targetPath += notebook.GetExtensionByLanguage(&objectInfo)
 
 	// Skip file if a file already exists in path.
 	// os.Stat returns a fs.ErrNotExist if a file does not exist at path.
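The export_dir change above swaps the inline language switch for the new `libs/notebook.GetExtensionByLanguage` helper (added near the end of this patch). A minimal usage sketch, assuming only the types that appear in this diff plus `workspace.ObjectTypeDirectory` from the SDK; the `main` wrapper is illustrative only:

```go
package main

import (
	"fmt"

	"github.com/databricks/cli/libs/notebook"
	"github.com/databricks/databricks-sdk-go/service/workspace"
)

func main() {
	// Notebooks get an extension derived from their language.
	nb := &workspace.ObjectInfo{
		ObjectType: workspace.ObjectTypeNotebook,
		Language:   workspace.LanguagePython,
	}
	fmt.Println("entry_notebook" + notebook.GetExtensionByLanguage(nb)) // entry_notebook.py

	// Anything that is not a notebook gets no extension at all.
	dir := &workspace.ObjectInfo{ObjectType: workspace.ObjectTypeDirectory}
	fmt.Println("src" + notebook.GetExtensionByLanguage(dir)) // src
}
```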
diff --git a/internal/bundle/bundles/with_includes/databricks_template_schema.json b/internal/bundle/bundles/with_includes/databricks_template_schema.json
new file mode 100644
index 00000000..216bc4c1
--- /dev/null
+++ b/internal/bundle/bundles/with_includes/databricks_template_schema.json
@@ -0,0 +1,8 @@
+{
+    "properties": {
+        "unique_id": {
+            "type": "string",
+            "description": "Unique ID for bundle"
+        }
+    }
+}
diff --git a/internal/bundle/bundles/with_includes/template/databricks.yml.tmpl b/internal/bundle/bundles/with_includes/template/databricks.yml.tmpl
new file mode 100644
index 00000000..5d17e0fd
--- /dev/null
+++ b/internal/bundle/bundles/with_includes/template/databricks.yml.tmpl
@@ -0,0 +1,8 @@
+bundle:
+  name: with_includes
+
+workspace:
+  root_path: "~/.bundle/{{.unique_id}}"
+
+include:
+  - resources/*yml
diff --git a/internal/bundle/generate_job_test.go b/internal/bundle/generate_job_test.go
new file mode 100644
index 00000000..c4f191e7
--- /dev/null
+++ b/internal/bundle/generate_job_test.go
@@ -0,0 +1,124 @@
+package bundle
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"path"
+	"path/filepath"
+	"strings"
+	"testing"
+
+	"github.com/databricks/cli/internal"
+	"github.com/databricks/cli/internal/testutil"
+	"github.com/databricks/cli/libs/filer"
+	"github.com/databricks/databricks-sdk-go"
+	"github.com/databricks/databricks-sdk-go/service/compute"
+	"github.com/databricks/databricks-sdk-go/service/jobs"
+	"github.com/google/uuid"
+	"github.com/stretchr/testify/require"
+)
+
+func TestAccGenerateFromExistingJobAndDeploy(t *testing.T) {
+	env := internal.GetEnvOrSkipTest(t, "CLOUD_ENV")
+	t.Log(env)
+
+	uniqueId := uuid.New().String()
+	bundleRoot, err := initTestTemplate(t, "with_includes", map[string]any{
+		"unique_id": uniqueId,
+	})
+	require.NoError(t, err)
+
+	jobId := createTestJob(t)
+	t.Cleanup(func() {
+		destroyJob(t, jobId)
+		require.NoError(t, err)
+	})
+
+	t.Setenv("BUNDLE_ROOT", bundleRoot)
+	c := internal.NewCobraTestRunner(t, "bundle", "generate", "job",
+		"--existing-job-id", fmt.Sprint(jobId),
+		"--config-dir", filepath.Join(bundleRoot, "resources"),
+		"--source-dir", filepath.Join(bundleRoot, "src"))
+	_, _, err = c.Run()
+	require.NoError(t, err)
+
+	_, err = os.Stat(filepath.Join(bundleRoot, "src", "test.py"))
+	require.NoError(t, err)
+
+	matches, err := filepath.Glob(filepath.Join(bundleRoot, "resources", "job_generated_job_*.yml"))
+	require.NoError(t, err)
+	require.Len(t, matches, 1)
+
+	// Check the content of the generated YAML.
+	data, err := os.ReadFile(matches[0])
+	require.NoError(t, err)
+	generatedYaml := string(data)
+	require.Contains(t, generatedYaml, "notebook_task:")
+	require.Contains(t, generatedYaml, "notebook_path: ../src/test.py")
+	require.Contains(t, generatedYaml, "task_key: test")
+	require.Contains(t, generatedYaml, "new_cluster:")
+	require.Contains(t, generatedYaml, "spark_version: 13.3.x-scala2.12")
+	require.Contains(t, generatedYaml, "num_workers: 1")
+
+	err = deployBundle(t, bundleRoot)
+	require.NoError(t, err)
+
+	err = destroyBundle(t, bundleRoot)
+	require.NoError(t, err)
+}
+
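+// createTestJob provisions the fixture for the test above: it uploads a
+// one-line Python notebook to a temporary workspace directory and creates
+// a job with a single notebook task on a small new cluster.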
+func createTestJob(t *testing.T) int64 {
+	var nodeTypeId string
+	switch testutil.GetCloud(t) {
+	case testutil.AWS:
+		nodeTypeId = "i3.xlarge"
+	case testutil.Azure:
+		nodeTypeId = "Standard_DS4_v2"
+	case testutil.GCP:
+		nodeTypeId = "n1-standard-4"
+	}
+
+	w, err := databricks.NewWorkspaceClient()
+	require.NoError(t, err)
+
+	ctx := context.Background()
+	tmpdir := internal.TemporaryWorkspaceDir(t, w)
+	f, err := filer.NewWorkspaceFilesClient(w, tmpdir)
+	require.NoError(t, err)
+
+	err = f.Write(ctx, "test.py", strings.NewReader("# Databricks notebook source\nprint('Hello world!')"))
+	require.NoError(t, err)
+
+	resp, err := w.Jobs.Create(ctx, jobs.CreateJob{
+		Name: internal.RandomName("generated-job-"),
+		Tasks: []jobs.Task{
+			{
+				TaskKey: "test",
+				NewCluster: &compute.ClusterSpec{
+					SparkVersion: "13.3.x-scala2.12",
+					NumWorkers:   1,
+					NodeTypeId:   nodeTypeId,
+				},
+				NotebookTask: &jobs.NotebookTask{
+					NotebookPath: path.Join(tmpdir, "test"),
+				},
+			},
+		},
+	})
+	require.NoError(t, err)
+
+	return resp.JobId
+}
+
+func destroyJob(t *testing.T, jobId int64) {
+	w, err := databricks.NewWorkspaceClient()
+	require.NoError(t, err)
+
+	ctx := context.Background()
+	err = w.Jobs.Delete(ctx, jobs.DeleteJob{
+		JobId: jobId,
+	})
+	require.NoError(t, err)
+}
diff --git a/libs/dyn/value.go b/libs/dyn/value.go
index 0dccb8b7..bbb8ad3e 100644
--- a/libs/dyn/value.go
+++ b/libs/dyn/value.go
@@ -46,6 +46,10 @@ func (v Value) Kind() Kind {
 	return v.k
 }
 
+func (v Value) Value() any {
+	return v.v
+}
+
 func (v Value) Location() Location {
 	return v.l
 }
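The `yamlsaver` package introduced below orders map keys via `dyn.Location` line numbers. Predefined keys receive negative indexes so they always sort ahead of everything else; unknown keys get an increasing positive index in first-use order. A minimal sketch, consistent with the unit tests in `order_test.go` below:

```go
package main

import (
	"fmt"

	"github.com/databricks/cli/libs/dyn/yamlsaver"
)

func main() {
	o := yamlsaver.NewOrder([]string{"name", "tasks"})
	fmt.Println(o.Get("name"))   // -2: first predefined key
	fmt.Println(o.Get("tasks"))  // -1: second predefined key
	fmt.Println(o.Get("tags"))   // 1: not predefined, running index
	fmt.Println(o.Get("format")) // 2
}
```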
diff --git a/libs/dyn/yamlsaver/order.go b/libs/dyn/yamlsaver/order.go
new file mode 100644
index 00000000..439be847
--- /dev/null
+++ b/libs/dyn/yamlsaver/order.go
@@ -0,0 +1,33 @@
+package yamlsaver
+
+import "slices"
+
+// Order is used to generate indexes for the ordering of map keys.
+// The ordering is defined by the predefined key list in the `order` field,
+// with a running index (`index`) for keys that are not predefined.
+type Order struct {
+	index int
+	order []string
+}
+
+func NewOrder(o []string) *Order {
+	return &Order{index: 0, order: o}
+}
+
+// Get returns an integer representing the position of a map key in the
+// resulting output. The lower the index, the earlier in the list the key is.
+// If the key is not predefined, Get uses the running order: every subsequent
+// call returns an increasing index.
+func (o *Order) Get(key string) int {
+	index := slices.Index(o.order, key)
+	// If the key is found in the predefined order list, we return a negative
+	// index, which puts it ahead of all keys that are not predefined.
+	// The earlier a key appears in the predefined list, the lower (more
+	// negative) the returned index.
+	if index != -1 {
+		return index - len(o.order)
+	}
+
+	// Otherwise we just increase the running index.
+	o.index += 1
+	return o.index
+}
diff --git a/libs/dyn/yamlsaver/order_test.go b/libs/dyn/yamlsaver/order_test.go
new file mode 100644
index 00000000..ed2877f6
--- /dev/null
+++ b/libs/dyn/yamlsaver/order_test.go
@@ -0,0 +1,24 @@
+package yamlsaver
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestOrderReturnsIncreasingIndex(t *testing.T) {
+	o := NewOrder([]string{})
+	assert.Equal(t, 1, o.Get("a"))
+	assert.Equal(t, 2, o.Get("b"))
+	assert.Equal(t, 3, o.Get("c"))
+}
+
+func TestOrderReturnsNegativeIndexForPredefinedKeys(t *testing.T) {
+	o := NewOrder([]string{"a", "b", "c"})
+	assert.Equal(t, -3, o.Get("a"))
+	assert.Equal(t, -2, o.Get("b"))
+	assert.Equal(t, -1, o.Get("c"))
+	assert.Equal(t, 1, o.Get("d"))
+	assert.Equal(t, 2, o.Get("e"))
+	assert.Equal(t, 3, o.Get("f"))
+}
diff --git a/libs/dyn/yamlsaver/saver.go b/libs/dyn/yamlsaver/saver.go
new file mode 100644
index 00000000..f5863ecf
--- /dev/null
+++ b/libs/dyn/yamlsaver/saver.go
@@ -0,0 +1,139 @@
+package yamlsaver
+
+import (
+	"fmt"
+	"io"
+	"os"
+	"path/filepath"
+	"sort"
+	"strconv"
+
+	"github.com/databricks/cli/libs/dyn"
+	"golang.org/x/exp/maps"
+	"gopkg.in/yaml.v3"
+)
+
+func SaveAsYAML(data any, filename string, force bool) error {
+	err := os.MkdirAll(filepath.Dir(filename), 0755)
+	if err != nil {
+		return err
+	}
+
+	// Check whether the file already exists.
+	info, err := os.Stat(filename)
+	if err == nil {
+		if info.IsDir() {
+			return fmt.Errorf("%s is a directory", filename)
+		}
+		if !force {
+			return fmt.Errorf("%s already exists. Use --force to overwrite", filename)
+		}
+	}
+
+	file, err := os.Create(filename)
+	if err != nil {
+		return err
+	}
+	defer file.Close()
+
+	err = encode(data, file)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func encode(data any, w io.Writer) error {
+	yamlNode, err := ToYamlNode(dyn.V(data))
+	if err != nil {
+		return err
+	}
+	enc := yaml.NewEncoder(w)
+	enc.SetIndent(2)
+	return enc.Encode(yamlNode)
+}
+
+func ToYamlNode(v dyn.Value) (*yaml.Node, error) {
+	switch v.Kind() {
+	case dyn.KindMap:
+		m, _ := v.AsMap()
+		keys := maps.Keys(m)
+		// We're using location lines to define the order of keys in YAML.
+		// The location is set when we convert an API response struct to its
+		// dyn.Value representation. See convert.convertMap for details.
+		sort.SliceStable(keys, func(i, j int) bool {
+			return m[keys[i]].Location().Line < m[keys[j]].Location().Line
+		})
+
+		content := make([]*yaml.Node, 0)
+		for _, k := range keys {
+			item := m[k]
+			node := yaml.Node{Kind: yaml.ScalarNode, Value: k}
+			c, err := ToYamlNode(item)
+			if err != nil {
+				return nil, err
+			}
+			content = append(content, &node)
+			content = append(content, c)
+		}
+
+		return &yaml.Node{Kind: yaml.MappingNode, Content: content}, nil
+	case dyn.KindSequence:
+		s, _ := v.AsSequence()
+		content := make([]*yaml.Node, 0)
+		for _, item := range s {
+			node, err := ToYamlNode(item)
+			if err != nil {
+				return nil, err
+			}
+			content = append(content, node)
+		}
+		return &yaml.Node{Kind: yaml.SequenceNode, Content: content}, nil
+	case dyn.KindNil:
+		return &yaml.Node{Kind: yaml.ScalarNode, Value: "null"}, nil
+	case dyn.KindString:
+		// If the string is a scalar value (bool, int, float, etc.), we want to quote it.
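+		// Quoting keeps the round trip lossless: without it, re-parsing the
+		// emitted YAML would turn strings like "true" or "1.0" into bool/float.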
+ if isScalarValueInString(v) { + return &yaml.Node{Kind: yaml.ScalarNode, Value: v.MustString(), Style: yaml.DoubleQuotedStyle}, nil + } + return &yaml.Node{Kind: yaml.ScalarNode, Value: v.MustString()}, nil + case dyn.KindBool: + return &yaml.Node{Kind: yaml.ScalarNode, Value: fmt.Sprint(v.MustBool())}, nil + case dyn.KindInt: + return &yaml.Node{Kind: yaml.ScalarNode, Value: fmt.Sprint(v.MustInt())}, nil + case dyn.KindFloat: + return &yaml.Node{Kind: yaml.ScalarNode, Value: fmt.Sprint(v.MustFloat())}, nil + case dyn.KindTime: + return &yaml.Node{Kind: yaml.ScalarNode, Value: v.MustTime().UTC().String()}, nil + default: + // Panic because we only want to deal with known types. + panic(fmt.Sprintf("invalid kind: %d", v.Kind())) + } +} + +func isScalarValueInString(v dyn.Value) bool { + if v.Kind() != dyn.KindString { + return false + } + + // Parse value of the string and check if it's a scalar value. + // If it's a scalar value, we want to quote it. + switch v.MustString() { + case "true", "false": + return true + default: + _, err := parseNumber(v.MustString()) + return err == nil + } +} + +func parseNumber(s string) (any, error) { + if i, err := strconv.ParseInt(s, 0, 64); err == nil { + return i, nil + } + + if f, err := strconv.ParseFloat(s, 64); err == nil { + return f, nil + } + return nil, fmt.Errorf("invalid number: %s", s) +} diff --git a/libs/dyn/yamlsaver/saver_test.go b/libs/dyn/yamlsaver/saver_test.go new file mode 100644 index 00000000..70878d55 --- /dev/null +++ b/libs/dyn/yamlsaver/saver_test.go @@ -0,0 +1,195 @@ +package yamlsaver + +import ( + "testing" + "time" + + "github.com/databricks/cli/libs/dyn" + "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v3" +) + +func TestMarshalNilValue(t *testing.T) { + var nilValue = dyn.NilValue + v, err := ToYamlNode(nilValue) + assert.NoError(t, err) + assert.Equal(t, "null", v.Value) +} + +func TestMarshalIntValue(t *testing.T) { + var intValue = dyn.NewValue(1, dyn.Location{}) + v, err := ToYamlNode(intValue) + assert.NoError(t, err) + assert.Equal(t, "1", v.Value) + assert.Equal(t, yaml.ScalarNode, v.Kind) +} + +func TestMarshalFloatValue(t *testing.T) { + var floatValue = dyn.NewValue(1.0, dyn.Location{}) + v, err := ToYamlNode(floatValue) + assert.NoError(t, err) + assert.Equal(t, "1", v.Value) + assert.Equal(t, yaml.ScalarNode, v.Kind) +} + +func TestMarshalBoolValue(t *testing.T) { + var boolValue = dyn.NewValue(true, dyn.Location{}) + v, err := ToYamlNode(boolValue) + assert.NoError(t, err) + assert.Equal(t, "true", v.Value) + assert.Equal(t, yaml.ScalarNode, v.Kind) +} + +func TestMarshalTimeValue(t *testing.T) { + var timeValue = dyn.NewValue(time.Unix(0, 0), dyn.Location{}) + v, err := ToYamlNode(timeValue) + assert.NoError(t, err) + assert.Equal(t, "1970-01-01 00:00:00 +0000 UTC", v.Value) + assert.Equal(t, yaml.ScalarNode, v.Kind) +} + +func TestMarshalSequenceValue(t *testing.T) { + var sequenceValue = dyn.NewValue( + []dyn.Value{ + dyn.NewValue("value1", dyn.Location{File: "file", Line: 1, Column: 2}), + dyn.NewValue("value2", dyn.Location{File: "file", Line: 2, Column: 2}), + }, + dyn.Location{File: "file", Line: 1, Column: 2}, + ) + v, err := ToYamlNode(sequenceValue) + assert.NoError(t, err) + assert.Equal(t, yaml.SequenceNode, v.Kind) + assert.Equal(t, "value1", v.Content[0].Value) + assert.Equal(t, "value2", v.Content[1].Value) +} + +func TestMarshalStringValue(t *testing.T) { + var stringValue = dyn.NewValue("value", dyn.Location{}) + v, err := ToYamlNode(stringValue) + assert.NoError(t, err) + 
assert.Equal(t, "value", v.Value) + assert.Equal(t, yaml.ScalarNode, v.Kind) +} + +func TestMarshalMapValue(t *testing.T) { + var mapValue = dyn.NewValue( + map[string]dyn.Value{ + "key3": dyn.NewValue("value3", dyn.Location{File: "file", Line: 3, Column: 2}), + "key2": dyn.NewValue("value2", dyn.Location{File: "file", Line: 2, Column: 2}), + "key1": dyn.NewValue("value1", dyn.Location{File: "file", Line: 1, Column: 2}), + }, + dyn.Location{File: "file", Line: 1, Column: 2}, + ) + v, err := ToYamlNode(mapValue) + assert.NoError(t, err) + assert.Equal(t, yaml.MappingNode, v.Kind) + assert.Equal(t, "key1", v.Content[0].Value) + assert.Equal(t, "value1", v.Content[1].Value) + + assert.Equal(t, "key2", v.Content[2].Value) + assert.Equal(t, "value2", v.Content[3].Value) + + assert.Equal(t, "key3", v.Content[4].Value) + assert.Equal(t, "value3", v.Content[5].Value) +} + +func TestMarshalNestedValues(t *testing.T) { + var mapValue = dyn.NewValue( + map[string]dyn.Value{ + "key1": dyn.NewValue( + map[string]dyn.Value{ + "key2": dyn.NewValue("value", dyn.Location{File: "file", Line: 1, Column: 2}), + }, + dyn.Location{File: "file", Line: 1, Column: 2}, + ), + }, + dyn.Location{File: "file", Line: 1, Column: 2}, + ) + v, err := ToYamlNode(mapValue) + assert.NoError(t, err) + assert.Equal(t, yaml.MappingNode, v.Kind) + assert.Equal(t, "key1", v.Content[0].Value) + assert.Equal(t, yaml.MappingNode, v.Content[1].Kind) + assert.Equal(t, "key2", v.Content[1].Content[0].Value) + assert.Equal(t, "value", v.Content[1].Content[1].Value) +} + +func TestMarshalHexadecimalValueIsQuoted(t *testing.T) { + var hexValue = dyn.NewValue(0x123, dyn.Location{}) + v, err := ToYamlNode(hexValue) + assert.NoError(t, err) + assert.Equal(t, "291", v.Value) + assert.Equal(t, yaml.Style(0), v.Style) + assert.Equal(t, yaml.ScalarNode, v.Kind) + + var stringValue = dyn.NewValue("0x123", dyn.Location{}) + v, err = ToYamlNode(stringValue) + assert.NoError(t, err) + assert.Equal(t, "0x123", v.Value) + assert.Equal(t, yaml.DoubleQuotedStyle, v.Style) + assert.Equal(t, yaml.ScalarNode, v.Kind) +} + +func TestMarshalBinaryValueIsQuoted(t *testing.T) { + var binaryValue = dyn.NewValue(0b101, dyn.Location{}) + v, err := ToYamlNode(binaryValue) + assert.NoError(t, err) + assert.Equal(t, "5", v.Value) + assert.Equal(t, yaml.Style(0), v.Style) + assert.Equal(t, yaml.ScalarNode, v.Kind) + + var stringValue = dyn.NewValue("0b101", dyn.Location{}) + v, err = ToYamlNode(stringValue) + assert.NoError(t, err) + assert.Equal(t, "0b101", v.Value) + assert.Equal(t, yaml.DoubleQuotedStyle, v.Style) + assert.Equal(t, yaml.ScalarNode, v.Kind) +} + +func TestMarshalOctalValueIsQuoted(t *testing.T) { + var octalValue = dyn.NewValue(0123, dyn.Location{}) + v, err := ToYamlNode(octalValue) + assert.NoError(t, err) + assert.Equal(t, "83", v.Value) + assert.Equal(t, yaml.Style(0), v.Style) + assert.Equal(t, yaml.ScalarNode, v.Kind) + + var stringValue = dyn.NewValue("0123", dyn.Location{}) + v, err = ToYamlNode(stringValue) + assert.NoError(t, err) + assert.Equal(t, "0123", v.Value) + assert.Equal(t, yaml.DoubleQuotedStyle, v.Style) + assert.Equal(t, yaml.ScalarNode, v.Kind) +} + +func TestMarshalFloatValueIsQuoted(t *testing.T) { + var floatValue = dyn.NewValue(1.0, dyn.Location{}) + v, err := ToYamlNode(floatValue) + assert.NoError(t, err) + assert.Equal(t, "1", v.Value) + assert.Equal(t, yaml.Style(0), v.Style) + assert.Equal(t, yaml.ScalarNode, v.Kind) + + var stringValue = dyn.NewValue("1.0", dyn.Location{}) + v, err = ToYamlNode(stringValue) + 
assert.NoError(t, err)
+	assert.Equal(t, "1.0", v.Value)
+	assert.Equal(t, yaml.DoubleQuotedStyle, v.Style)
+	assert.Equal(t, yaml.ScalarNode, v.Kind)
+}
+
+func TestMarshalBoolValueIsQuoted(t *testing.T) {
+	var boolValue = dyn.NewValue(true, dyn.Location{})
+	v, err := ToYamlNode(boolValue)
+	assert.NoError(t, err)
+	assert.Equal(t, "true", v.Value)
+	assert.Equal(t, yaml.Style(0), v.Style)
+	assert.Equal(t, yaml.ScalarNode, v.Kind)
+
+	var stringValue = dyn.NewValue("true", dyn.Location{})
+	v, err = ToYamlNode(stringValue)
+	assert.NoError(t, err)
+	assert.Equal(t, "true", v.Value)
+	assert.Equal(t, yaml.DoubleQuotedStyle, v.Style)
+	assert.Equal(t, yaml.ScalarNode, v.Kind)
+}
diff --git a/libs/dyn/yamlsaver/utils.go b/libs/dyn/yamlsaver/utils.go
new file mode 100644
index 00000000..0fb4064b
--- /dev/null
+++ b/libs/dyn/yamlsaver/utils.go
@@ -0,0 +1,49 @@
+package yamlsaver
+
+import (
+	"fmt"
+	"slices"
+
+	"github.com/databricks/cli/libs/dyn"
+	"github.com/databricks/cli/libs/dyn/convert"
+)
+
+// ConvertToMapValue converts a struct to a map, skipping any nil fields.
+// It uses `skipFields` to skip unnecessary fields and `order` to define
+// the order of keys in the resulting output.
+func ConvertToMapValue(strct any, order *Order, skipFields []string, dst map[string]dyn.Value) (dyn.Value, error) {
+	ref := dyn.NilValue
+	mv, err := convert.FromTyped(strct, ref)
+	if err != nil {
+		return dyn.NilValue, err
+	}
+
+	if mv.Kind() != dyn.KindMap {
+		return dyn.InvalidValue, fmt.Errorf("expected map, got %s", mv.Kind())
+	}
+
+	return skipAndOrder(mv, order, skipFields, dst)
+}
+
+func skipAndOrder(mv dyn.Value, order *Order, skipFields []string, dst map[string]dyn.Value) (dyn.Value, error) {
+	for k, v := range mv.MustMap() {
+		if v.Kind() == dyn.KindNil {
+			continue
+		}
+
+		if slices.Contains(skipFields, k) {
+			continue
+		}
+
+		// If the value is already defined in the destination, it was set
+		// manually due to custom ordering or other required customisation,
+		// so we skip processing it again.
+		if _, ok := dst[k]; ok {
+			continue
+		}
+
+		dst[k] = dyn.NewValue(v.Value(), dyn.Location{Line: order.Get(k)})
+	}
+
+	return dyn.V(dst), nil
+}
diff --git a/libs/dyn/yamlsaver/utils_test.go b/libs/dyn/yamlsaver/utils_test.go
new file mode 100644
index 00000000..32c9143b
--- /dev/null
+++ b/libs/dyn/yamlsaver/utils_test.go
@@ -0,0 +1,48 @@
+package yamlsaver
+
+import (
+	"testing"
+
+	"github.com/databricks/cli/libs/dyn"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestConvertToMapValueWithOrder(t *testing.T) {
+	type test struct {
+		Name            string            `json:"name"`
+		Map             map[string]string `json:"map"`
+		List            []string          `json:"list"`
+		LongNameField   string            `json:"long_name_field"`
+		ForceSendFields []string          `json:"-"`
+		Format          string            `json:"format"`
+	}
+
+	v := &test{
+		Name: "test",
+		Map: map[string]string{
+			"key1": "value1",
+			"key2": "value2",
+		},
+		List: []string{"a", "b", "c"},
+		ForceSendFields: []string{
+			"Name",
+		},
+		LongNameField: "long name goes here",
+	}
+	result, err := ConvertToMapValue(v, NewOrder([]string{"list", "name", "map"}), []string{"format"}, map[string]dyn.Value{})
+	assert.NoError(t, err)
+
+	assert.Equal(t, map[string]dyn.Value{
+		"list": dyn.NewValue([]dyn.Value{
+			dyn.V("a"),
+			dyn.V("b"),
+			dyn.V("c"),
+		}, dyn.Location{Line: -3}),
+		"name": dyn.NewValue("test", dyn.Location{Line: -2}),
+		"map": dyn.NewValue(map[string]dyn.Value{
+			"key1": dyn.V("value1"),
+			"key2": dyn.V("value2"),
+		}, dyn.Location{Line: -1}),
+		"long_name_field": dyn.NewValue("long name goes here", dyn.Location{Line: 1}),
+	}, result.MustMap())
+}
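To make the `ConvertToMapValue` contract concrete, a short usage sketch under the same ordering rules; the `cluster` struct is hypothetical, standing in for an SDK response type with JSON tags:

```go
package main

import (
	"fmt"

	"github.com/databricks/cli/libs/dyn"
	"github.com/databricks/cli/libs/dyn/yamlsaver"
)

type cluster struct {
	SparkVersion string `json:"spark_version"`
	NumWorkers   int    `json:"num_workers"`
	Format       string `json:"format"`
}

func main() {
	src := &cluster{SparkVersion: "13.3.x-scala2.12", NumWorkers: 1, Format: "MULTI_TASK"}
	v, err := yamlsaver.ConvertToMapValue(
		src,
		yamlsaver.NewOrder([]string{"spark_version"}), // sorts first via line -1
		[]string{"format"},                            // dropped from the output
		map[string]dyn.Value{},
	)
	if err != nil {
		panic(err)
	}

	m := v.MustMap()
	fmt.Println(m["spark_version"].MustString())    // 13.3.x-scala2.12
	fmt.Println(m["spark_version"].Location().Line) // -1
	_, ok := m["format"]
	fmt.Println(ok) // false: listed in skipFields
}
```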
diff --git a/libs/notebook/ext.go b/libs/notebook/ext.go
new file mode 100644
index 00000000..28d08c11
--- /dev/null
+++ b/libs/notebook/ext.go
@@ -0,0 +1,23 @@
+package notebook
+
+import "github.com/databricks/databricks-sdk-go/service/workspace"
+
+func GetExtensionByLanguage(objectInfo *workspace.ObjectInfo) string {
+	if objectInfo.ObjectType != workspace.ObjectTypeNotebook {
+		return ""
+	}
+
+	switch objectInfo.Language {
+	case workspace.LanguagePython:
+		return ".py"
+	case workspace.LanguageR:
+		return ".r"
+	case workspace.LanguageScala:
+		return ".scala"
+	case workspace.LanguageSql:
+		return ".sql"
+	default:
+		// Do not add any extension to the file name.
+		return ""
+	}
+}
diff --git a/libs/textutil/textutil.go b/libs/textutil/textutil.go
new file mode 100644
index 00000000..a5d17d55
--- /dev/null
+++ b/libs/textutil/textutil.go
@@ -0,0 +1,20 @@
+package textutil
+
+import (
+	"strings"
+	"unicode"
+)
+
+// NormalizeString leaves the full range of unicode letters intact, but
+// removes all "special" characters, including spaces and dots, which are
+// not supported in e.g. experiment names or YAML keys.
+func NormalizeString(name string) string {
+	name = strings.ToLower(name)
+	return strings.Map(replaceNonAlphanumeric, name)
+}
+
+func replaceNonAlphanumeric(r rune) rune {
+	if unicode.IsLetter(r) || unicode.IsDigit(r) {
+		return r
+	}
+	return '_'
+}
diff --git a/libs/textutil/textutil_test.go b/libs/textutil/textutil_test.go
new file mode 100644
index 00000000..fb8bf0b6
--- /dev/null
+++ b/libs/textutil/textutil_test.go
@@ -0,0 +1,54 @@
+package textutil
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestNormalizeString(t *testing.T) {
+	cases := []struct {
+		input    string
+		expected string
+	}{
+		{
+			input:    "test",
+			expected: "test",
+		},
+		{
+			input:    "test test",
+			expected: "test_test",
+		},
+		{
+			input:    "test-test",
+			expected: "test_test",
+		},
+		{
+			input:    "test_test",
+			expected: "test_test",
+		},
+		{
+			input:    "test.test",
+			expected: "test_test",
+		},
+		{
+			input:    "test/test",
+			expected: "test_test",
+		},
+		{
+			input:    "test/test.test",
+			expected: "test_test_test",
+		},
+		{
+			input:    "TestTest",
+			expected: "testtest",
+		},
+		{
+			input:    "TestTestTest",
+			expected: "testtesttest",
+		},
+	}
+
+	for _, c := range cases {
+		assert.Equal(t, c.expected, NormalizeString(c.input))
+	}
+}
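Finally, tying it back to `cmd/bundle/generate/job.go`: the resource key and the output filename are derived from the job name with `textutil.NormalizeString`. A sketch with a made-up job name:

```go
package main

import (
	"fmt"
	"path/filepath"

	"github.com/databricks/cli/libs/textutil"
)

func main() {
	// Mirrors the jobKey/filename construction in the generate command.
	name := "My Notebook Job v2.1" // hypothetical job name
	jobKey := fmt.Sprintf("job_%s", textutil.NormalizeString(name))
	filename := filepath.Join("resources", fmt.Sprintf("%s.yml", jobKey))
	fmt.Println(jobKey)   // job_my_notebook_job_v2_1
	fmt.Println(filename) // resources/job_my_notebook_job_v2_1.yml
}
```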