diff --git a/bundle/config/generate/job.go b/bundle/config/generate/job.go new file mode 100644 index 00000000..469f8422 --- /dev/null +++ b/bundle/config/generate/job.go @@ -0,0 +1,34 @@ +package generate + +import ( + "github.com/databricks/cli/libs/dyn" + "github.com/databricks/cli/libs/dyn/yamlsaver" + "github.com/databricks/databricks-sdk-go/service/jobs" +) + +var jobOrder = yamlsaver.NewOrder([]string{"name", "job_clusters", "compute", "tasks"}) +var taskOrder = yamlsaver.NewOrder([]string{"task_key", "depends_on", "existing_cluster_id", "new_cluster", "job_cluster_key"}) + +func ConvertJobToValue(job *jobs.Job) (dyn.Value, error) { + value := make(map[string]dyn.Value) + + if job.Settings.Tasks != nil { + tasks := make([]dyn.Value, 0) + for _, task := range job.Settings.Tasks { + v, err := convertTaskToValue(task, taskOrder) + if err != nil { + return dyn.NilValue, err + } + tasks = append(tasks, v) + } + // We're using location lines to define the order of keys in exported YAML. 
+ value["tasks"] = dyn.NewValue(tasks, dyn.Location{Line: jobOrder.Get("tasks")}) + } + + return yamlsaver.ConvertToMapValue(job.Settings, jobOrder, []string{"format", "new_cluster", "existing_cluster_id"}, value) +} + +func convertTaskToValue(task jobs.Task, order *yamlsaver.Order) (dyn.Value, error) { + dst := make(map[string]dyn.Value) + return yamlsaver.ConvertToMapValue(task, order, []string{"format"}, dst) +} diff --git a/bundle/config/mutator/populate_current_user.go b/bundle/config/mutator/populate_current_user.go index 5b5d3096..60587578 100644 --- a/bundle/config/mutator/populate_current_user.go +++ b/bundle/config/mutator/populate_current_user.go @@ -3,11 +3,11 @@ package mutator import ( "context" "strings" - "unicode" "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config" "github.com/databricks/cli/libs/tags" + "github.com/databricks/cli/libs/textutil" ) type populateCurrentUser struct{} @@ -43,17 +43,10 @@ func (m *populateCurrentUser) Apply(ctx context.Context, b *bundle.Bundle) error return nil } -func replaceNonAlphanumeric(r rune) rune { - if unicode.IsLetter(r) || unicode.IsDigit(r) { - return r - } - return '_' -} - // Get a short-form username, based on the user's primary email address. // We leave the full range of unicode letters in tact, but remove all "special" characters, // including dots, which are not supported in e.g. experiment names. 
func getShortUserName(emailAddress string) string { local, _, _ := strings.Cut(emailAddress, "@") - return strings.Map(replaceNonAlphanumeric, local) + return textutil.NormalizeString(local) } diff --git a/cmd/bundle/bundle.go b/cmd/bundle/bundle.go index 3206b94e..3aa6945b 100644 --- a/cmd/bundle/bundle.go +++ b/cmd/bundle/bundle.go @@ -22,5 +22,6 @@ func New() *cobra.Command { cmd.AddCommand(newTestCommand()) cmd.AddCommand(newValidateCommand()) cmd.AddCommand(newInitCommand()) + cmd.AddCommand(newGenerateCommand()) return cmd } diff --git a/cmd/bundle/generate.go b/cmd/bundle/generate.go new file mode 100644 index 00000000..a593f52f --- /dev/null +++ b/cmd/bundle/generate.go @@ -0,0 +1,18 @@ +package bundle + +import ( + "github.com/databricks/cli/cmd/bundle/generate" + "github.com/spf13/cobra" +) + +func newGenerateCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "generate", + Short: "Generate bundle configuration", + Long: "Generate bundle configuration", + PreRunE: ConfigureBundleWithVariables, + } + + cmd.AddCommand(generate.NewGenerateJobCommand()) + return cmd +} diff --git a/cmd/bundle/generate/job.go b/cmd/bundle/generate/job.go new file mode 100644 index 00000000..8e186cc3 --- /dev/null +++ b/cmd/bundle/generate/job.go @@ -0,0 +1,91 @@ +package generate + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/databricks/cli/bundle" + "github.com/databricks/cli/bundle/config/generate" + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/dyn" + "github.com/databricks/cli/libs/dyn/yamlsaver" + "github.com/databricks/cli/libs/textutil" + "github.com/databricks/databricks-sdk-go/service/jobs" + "github.com/spf13/cobra" +) + +func NewGenerateJobCommand() *cobra.Command { + var configDir string + var sourceDir string + var jobId int64 + var force bool + + cmd := &cobra.Command{ + Use: "job", + Short: "Generate bundle configuration for a job", + PreRunE: root.MustConfigureBundle, + } + + 
cmd.Flags().Int64Var(&jobId, "existing-job-id", 0, `Job ID of the job to generate config for`) + cmd.MarkFlagRequired("existing-job-id") + + wd, err := os.Getwd() + if err != nil { + wd = "." + } + + cmd.Flags().StringVarP(&configDir, "config-dir", "d", filepath.Join(wd, "resources"), `Dir path where the output config will be stored`) + cmd.Flags().StringVarP(&sourceDir, "source-dir", "s", filepath.Join(wd, "src"), `Dir path where the downloaded files will be stored`) + cmd.Flags().BoolVarP(&force, "force", "f", false, `Force overwrite existing files in the output directory`) + + cmd.RunE = func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + b := bundle.Get(ctx) + w := b.WorkspaceClient() + + job, err := w.Jobs.Get(ctx, jobs.GetJobRequest{JobId: jobId}) + if err != nil { + return err + } + + downloader := newNotebookDownloader(w, sourceDir, configDir) + for _, task := range job.Settings.Tasks { + err := downloader.MarkForDownload(ctx, &task) + if err != nil { + return err + } + } + + v, err := generate.ConvertJobToValue(job) + if err != nil { + return err + } + + jobKey := fmt.Sprintf("job_%s", textutil.NormalizeString(job.Settings.Name)) + result := map[string]dyn.Value{ + "resources": dyn.V(map[string]dyn.Value{ + "jobs": dyn.V(map[string]dyn.Value{ + jobKey: v, + }), + }), + } + + err = downloader.FlushToDisk(ctx, force) + if err != nil { + return err + } + + filename := filepath.Join(configDir, fmt.Sprintf("%s.yml", jobKey)) + err = yamlsaver.SaveAsYAML(result, filename, force) + if err != nil { + return err + } + + cmdio.LogString(ctx, fmt.Sprintf("Job configuration successfully saved to %s", filename)) + return nil + } + + return cmd +} diff --git a/cmd/bundle/generate/utils.go b/cmd/bundle/generate/utils.go new file mode 100644 index 00000000..8d450088 --- /dev/null +++ b/cmd/bundle/generate/utils.go @@ -0,0 +1,107 @@ +package generate + +import ( + "context" + "fmt" + "io" + "os" + "path" + "path/filepath" + + 
"github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/notebook" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/service/jobs" + "golang.org/x/sync/errgroup" +) + +type notebookDownloader struct { + notebooks map[string]string + w *databricks.WorkspaceClient + sourceDir string + configDir string +} + +func (n *notebookDownloader) MarkForDownload(ctx context.Context, task *jobs.Task) error { + if task.NotebookTask == nil { + return nil + } + + info, err := n.w.Workspace.GetStatusByPath(ctx, task.NotebookTask.NotebookPath) + if err != nil { + return err + } + + ext := notebook.GetExtensionByLanguage(info) + + filename := path.Base(task.NotebookTask.NotebookPath) + ext + targetPath := filepath.Join(n.sourceDir, filename) + + n.notebooks[targetPath] = task.NotebookTask.NotebookPath + + // Update the notebook path to be relative to the config dir + rel, err := filepath.Rel(n.configDir, targetPath) + if err != nil { + return err + } + + task.NotebookTask.NotebookPath = rel + return nil +} + +func (n *notebookDownloader) FlushToDisk(ctx context.Context, force bool) error { + err := os.MkdirAll(n.sourceDir, 0755) + if err != nil { + return err + } + + // First check that all files can be written + for targetPath := range n.notebooks { + info, err := os.Stat(targetPath) + if err == nil { + if info.IsDir() { + return fmt.Errorf("%s is a directory", targetPath) + } + if !force { + return fmt.Errorf("%s already exists. 
Use --force to overwrite", targetPath) + } + } + } + + errs, errCtx := errgroup.WithContext(ctx) + for k, v := range n.notebooks { + targetPath := k + notebookPath := v + errs.Go(func() error { + reader, err := n.w.Workspace.Download(errCtx, notebookPath) + if err != nil { + return err + } + + file, err := os.Create(targetPath) + if err != nil { + return err + } + defer file.Close() + + _, err = io.Copy(file, reader) + if err != nil { + return err + } + + cmdio.LogString(errCtx, fmt.Sprintf("Notebook successfully saved to %s", targetPath)) + return reader.Close() + }) + } + + return errs.Wait() +} + +func newNotebookDownloader(w *databricks.WorkspaceClient, sourceDir string, configDir string) *notebookDownloader { + return &notebookDownloader{ + notebooks: make(map[string]string), + w: w, + sourceDir: sourceDir, + configDir: configDir, + } +} diff --git a/cmd/workspace/workspace/export_dir.go b/cmd/workspace/workspace/export_dir.go index 4f50a96e..d2a86d00 100644 --- a/cmd/workspace/workspace/export_dir.go +++ b/cmd/workspace/workspace/export_dir.go @@ -11,6 +11,7 @@ import ( "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/filer" + "github.com/databricks/cli/libs/notebook" "github.com/databricks/databricks-sdk-go/service/workspace" "github.com/spf13/cobra" ) @@ -47,20 +48,7 @@ func (opts exportDirOptions) callback(ctx context.Context, workspaceFiler filer. return err } objectInfo := info.Sys().(workspace.ObjectInfo) - if objectInfo.ObjectType == workspace.ObjectTypeNotebook { - switch objectInfo.Language { - case workspace.LanguagePython: - targetPath += ".py" - case workspace.LanguageR: - targetPath += ".r" - case workspace.LanguageScala: - targetPath += ".scala" - case workspace.LanguageSql: - targetPath += ".sql" - default: - // Do not add any extension to the file name - } - } + targetPath += notebook.GetExtensionByLanguage(&objectInfo) // Skip file if a file already exists in path. 
// os.Stat returns a fs.ErrNotExist if a file does not exist at path. diff --git a/internal/bundle/bundles/with_includes/databricks_template_schema.json b/internal/bundle/bundles/with_includes/databricks_template_schema.json new file mode 100644 index 00000000..216bc4c1 --- /dev/null +++ b/internal/bundle/bundles/with_includes/databricks_template_schema.json @@ -0,0 +1,8 @@ +{ + "properties": { + "unique_id": { + "type": "string", + "description": "Unique ID for bundle" + } + } +} diff --git a/internal/bundle/bundles/with_includes/template/databricks.yml.tmpl b/internal/bundle/bundles/with_includes/template/databricks.yml.tmpl new file mode 100644 index 00000000..5d17e0fd --- /dev/null +++ b/internal/bundle/bundles/with_includes/template/databricks.yml.tmpl @@ -0,0 +1,8 @@ +bundle: + name: with_includes + +workspace: + root_path: "~/.bundle/{{.unique_id}}" + +includes: + - resources/*yml diff --git a/internal/bundle/generate_job_test.go b/internal/bundle/generate_job_test.go new file mode 100644 index 00000000..c4f191e7 --- /dev/null +++ b/internal/bundle/generate_job_test.go @@ -0,0 +1,124 @@ +package bundle + +import ( + "context" + "fmt" + "os" + "path" + "path/filepath" + "strings" + "testing" + + "github.com/databricks/cli/internal" + "github.com/databricks/cli/internal/testutil" + "github.com/databricks/cli/libs/filer" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/service/compute" + "github.com/databricks/databricks-sdk-go/service/jobs" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +func TestAccGenerateFromExistingJobAndDeploy(t *testing.T) { + env := internal.GetEnvOrSkipTest(t, "CLOUD_ENV") + t.Log(env) + + uniqueId := uuid.New().String() + bundleRoot, err := initTestTemplate(t, "with_includes", map[string]any{ + "unique_id": uniqueId, + }) + require.NoError(t, err) + + jobId := createTestJob(t) + t.Cleanup(func() { + destroyJob(t, jobId) + require.NoError(t, err) + }) + + 
t.Setenv("BUNDLE_ROOT", bundleRoot) + c := internal.NewCobraTestRunner(t, "bundle", "generate", "job", + "--existing-job-id", fmt.Sprint(jobId), + "--config-dir", filepath.Join(bundleRoot, "resources"), + "--source-dir", filepath.Join(bundleRoot, "src")) + _, _, err = c.Run() + require.NoError(t, err) + + _, err = os.Stat(filepath.Join(bundleRoot, "src", "test.py")) + require.NoError(t, err) + + matches, err := filepath.Glob(filepath.Join(bundleRoot, "resources", "job_generated_job_*.yml")) + require.NoError(t, err) + require.Len(t, matches, 1) + + // check the content of generated yaml + data, err := os.ReadFile(matches[0]) + require.NoError(t, err) + generatedYaml := string(data) + require.Contains(t, generatedYaml, "notebook_task:") + require.Contains(t, generatedYaml, "notebook_path: ../src/test.py") + require.Contains(t, generatedYaml, "task_key: test") + require.Contains(t, generatedYaml, "new_cluster:") + require.Contains(t, generatedYaml, "spark_version: 13.3.x-scala2.12") + require.Contains(t, generatedYaml, "num_workers: 1") + + err = deployBundle(t, bundleRoot) + require.NoError(t, err) + + err = destroyBundle(t, bundleRoot) + require.NoError(t, err) + +} + +func createTestJob(t *testing.T) int64 { + var nodeTypeId string + switch testutil.GetCloud(t) { + case testutil.AWS: + nodeTypeId = "i3.xlarge" + case testutil.Azure: + nodeTypeId = "Standard_DS4_v2" + case testutil.GCP: + nodeTypeId = "n1-standard-4" + } + + w, err := databricks.NewWorkspaceClient() + require.NoError(t, err) + + ctx := context.Background() + tmpdir := internal.TemporaryWorkspaceDir(t, w) + f, err := filer.NewWorkspaceFilesClient(w, tmpdir) + require.NoError(t, err) + + err = f.Write(ctx, "test.py", strings.NewReader("# Databricks notebook source\nprint('Hello world!'))")) + require.NoError(t, err) + + resp, err := w.Jobs.Create(ctx, jobs.CreateJob{ + Name: internal.RandomName("generated-job-"), + Tasks: []jobs.Task{ + { + TaskKey: "test", + NewCluster: &compute.ClusterSpec{ + 
+ SparkVersion: "13.3.x-scala2.12", + NumWorkers: 1, + NodeTypeId: nodeTypeId, + }, + NotebookTask: &jobs.NotebookTask{ + NotebookPath: path.Join(tmpdir, "test"), + }, + }, + }, + }) + require.NoError(t, err) + + return resp.JobId +} + +func destroyJob(t *testing.T, jobId int64) { + w, err := databricks.NewWorkspaceClient() + require.NoError(t, err) + + ctx := context.Background() + err = w.Jobs.Delete(ctx, jobs.DeleteJob{ + JobId: jobId, + }) + require.NoError(t, err) +} diff --git a/libs/dyn/value.go b/libs/dyn/value.go index 0dccb8b7..bbb8ad3e 100644 --- a/libs/dyn/value.go +++ b/libs/dyn/value.go @@ -46,6 +46,10 @@ func (v Value) Kind() Kind { return v.k } +func (v Value) Value() any { + return v.v +} + func (v Value) Location() Location { return v.l } diff --git a/libs/dyn/yamlsaver/order.go b/libs/dyn/yamlsaver/order.go new file mode 100644 index 00000000..439be847 --- /dev/null +++ b/libs/dyn/yamlsaver/order.go @@ -0,0 +1,33 @@ +package yamlsaver + +import "slices" + +// This struct is used to generate indexes for ordering of map keys. +// The ordering is defined based on the predefined order in the `order` field +// or a running order based on `index` +type Order struct { + index int + order []string +} + +func NewOrder(o []string) *Order { + return &Order{index: 0, order: o} +} + +// Returns an integer which represents the order of the map key in the resulting output. +// The lower the index, the earlier in the list the key is. +// If the order is not predefined, it uses running order and any subsequent call to +// order.Get returns an increasing index. +func (o *Order) Get(key string) int { + index := slices.Index(o.order, key) + // If the key is found in predefined order list + // We return a negative index which puts the value at the top of the order compared to other + // not predefined keys. 
The earlier value in predefined list, the lower negative index value + if index != -1 { + return index - len(o.order) + } + + // Otherwise we just increase the order index + o.index += 1 + return o.index +} diff --git a/libs/dyn/yamlsaver/order_test.go b/libs/dyn/yamlsaver/order_test.go new file mode 100644 index 00000000..ed2877f6 --- /dev/null +++ b/libs/dyn/yamlsaver/order_test.go @@ -0,0 +1,24 @@ +package yamlsaver + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestOrderReturnsIncreasingIndex(t *testing.T) { + o := NewOrder([]string{}) + assert.Equal(t, 1, o.Get("a")) + assert.Equal(t, 2, o.Get("b")) + assert.Equal(t, 3, o.Get("c")) +} + +func TestOrderReturnsNegativeIndexForPredefinedKeys(t *testing.T) { + o := NewOrder([]string{"a", "b", "c"}) + assert.Equal(t, -3, o.Get("a")) + assert.Equal(t, -2, o.Get("b")) + assert.Equal(t, -1, o.Get("c")) + assert.Equal(t, 1, o.Get("d")) + assert.Equal(t, 2, o.Get("e")) + assert.Equal(t, 3, o.Get("f")) +} diff --git a/libs/dyn/yamlsaver/saver.go b/libs/dyn/yamlsaver/saver.go new file mode 100644 index 00000000..f5863ecf --- /dev/null +++ b/libs/dyn/yamlsaver/saver.go @@ -0,0 +1,139 @@ +package yamlsaver + +import ( + "fmt" + "io" + "os" + "path/filepath" + "sort" + "strconv" + + "github.com/databricks/cli/libs/dyn" + "golang.org/x/exp/maps" + "gopkg.in/yaml.v3" +) + +func SaveAsYAML(data any, filename string, force bool) error { + err := os.MkdirAll(filepath.Dir(filename), 0755) + if err != nil { + return err + } + + // check that file exists + info, err := os.Stat(filename) + if err == nil { + if info.IsDir() { + return fmt.Errorf("%s is a directory", filename) + } + if !force { + return fmt.Errorf("%s already exists. 
Use --force to overwrite", filename) + } + } + + file, err := os.Create(filename) + if err != nil { + return err + } + defer file.Close() + + err = encode(data, file) + if err != nil { + return err + } + return nil +} + +func encode(data any, w io.Writer) error { + yamlNode, err := ToYamlNode(dyn.V(data)) + if err != nil { + return err + } + enc := yaml.NewEncoder(w) + enc.SetIndent(2) + return enc.Encode(yamlNode) +} + +func ToYamlNode(v dyn.Value) (*yaml.Node, error) { + switch v.Kind() { + case dyn.KindMap: + m, _ := v.AsMap() + keys := maps.Keys(m) + // We're using location lines to define the order of keys in YAML. + // The location is set when we convert API response struct to config.Value representation + // See convert.convertMap for details + sort.SliceStable(keys, func(i, j int) bool { + return m[keys[i]].Location().Line < m[keys[j]].Location().Line + }) + + content := make([]*yaml.Node, 0) + for _, k := range keys { + item := m[k] + node := yaml.Node{Kind: yaml.ScalarNode, Value: k} + c, err := ToYamlNode(item) + if err != nil { + return nil, err + } + content = append(content, &node) + content = append(content, c) + } + + return &yaml.Node{Kind: yaml.MappingNode, Content: content}, nil + case dyn.KindSequence: + s, _ := v.AsSequence() + content := make([]*yaml.Node, 0) + for _, item := range s { + node, err := ToYamlNode(item) + if err != nil { + return nil, err + } + content = append(content, node) + } + return &yaml.Node{Kind: yaml.SequenceNode, Content: content}, nil + case dyn.KindNil: + return &yaml.Node{Kind: yaml.ScalarNode, Value: "null"}, nil + case dyn.KindString: + // If the string is a scalar value (bool, int, float and etc.), we want to quote it. 
+ if isScalarValueInString(v) { + return &yaml.Node{Kind: yaml.ScalarNode, Value: v.MustString(), Style: yaml.DoubleQuotedStyle}, nil + } + return &yaml.Node{Kind: yaml.ScalarNode, Value: v.MustString()}, nil + case dyn.KindBool: + return &yaml.Node{Kind: yaml.ScalarNode, Value: fmt.Sprint(v.MustBool())}, nil + case dyn.KindInt: + return &yaml.Node{Kind: yaml.ScalarNode, Value: fmt.Sprint(v.MustInt())}, nil + case dyn.KindFloat: + return &yaml.Node{Kind: yaml.ScalarNode, Value: fmt.Sprint(v.MustFloat())}, nil + case dyn.KindTime: + return &yaml.Node{Kind: yaml.ScalarNode, Value: v.MustTime().UTC().String()}, nil + default: + // Panic because we only want to deal with known types. + panic(fmt.Sprintf("invalid kind: %d", v.Kind())) + } +} + +func isScalarValueInString(v dyn.Value) bool { + if v.Kind() != dyn.KindString { + return false + } + + // Parse value of the string and check if it's a scalar value. + // If it's a scalar value, we want to quote it. + switch v.MustString() { + case "true", "false": + return true + default: + _, err := parseNumber(v.MustString()) + return err == nil + } +} + +func parseNumber(s string) (any, error) { + if i, err := strconv.ParseInt(s, 0, 64); err == nil { + return i, nil + } + + if f, err := strconv.ParseFloat(s, 64); err == nil { + return f, nil + } + return nil, fmt.Errorf("invalid number: %s", s) +} diff --git a/libs/dyn/yamlsaver/saver_test.go b/libs/dyn/yamlsaver/saver_test.go new file mode 100644 index 00000000..70878d55 --- /dev/null +++ b/libs/dyn/yamlsaver/saver_test.go @@ -0,0 +1,195 @@ +package yamlsaver + +import ( + "testing" + "time" + + "github.com/databricks/cli/libs/dyn" + "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v3" +) + +func TestMarshalNilValue(t *testing.T) { + var nilValue = dyn.NilValue + v, err := ToYamlNode(nilValue) + assert.NoError(t, err) + assert.Equal(t, "null", v.Value) +} + +func TestMarshalIntValue(t *testing.T) { + var intValue = dyn.NewValue(1, dyn.Location{}) + v, err := 
ToYamlNode(intValue) + assert.NoError(t, err) + assert.Equal(t, "1", v.Value) + assert.Equal(t, yaml.ScalarNode, v.Kind) +} + +func TestMarshalFloatValue(t *testing.T) { + var floatValue = dyn.NewValue(1.0, dyn.Location{}) + v, err := ToYamlNode(floatValue) + assert.NoError(t, err) + assert.Equal(t, "1", v.Value) + assert.Equal(t, yaml.ScalarNode, v.Kind) +} + +func TestMarshalBoolValue(t *testing.T) { + var boolValue = dyn.NewValue(true, dyn.Location{}) + v, err := ToYamlNode(boolValue) + assert.NoError(t, err) + assert.Equal(t, "true", v.Value) + assert.Equal(t, yaml.ScalarNode, v.Kind) +} + +func TestMarshalTimeValue(t *testing.T) { + var timeValue = dyn.NewValue(time.Unix(0, 0), dyn.Location{}) + v, err := ToYamlNode(timeValue) + assert.NoError(t, err) + assert.Equal(t, "1970-01-01 00:00:00 +0000 UTC", v.Value) + assert.Equal(t, yaml.ScalarNode, v.Kind) +} + +func TestMarshalSequenceValue(t *testing.T) { + var sequenceValue = dyn.NewValue( + []dyn.Value{ + dyn.NewValue("value1", dyn.Location{File: "file", Line: 1, Column: 2}), + dyn.NewValue("value2", dyn.Location{File: "file", Line: 2, Column: 2}), + }, + dyn.Location{File: "file", Line: 1, Column: 2}, + ) + v, err := ToYamlNode(sequenceValue) + assert.NoError(t, err) + assert.Equal(t, yaml.SequenceNode, v.Kind) + assert.Equal(t, "value1", v.Content[0].Value) + assert.Equal(t, "value2", v.Content[1].Value) +} + +func TestMarshalStringValue(t *testing.T) { + var stringValue = dyn.NewValue("value", dyn.Location{}) + v, err := ToYamlNode(stringValue) + assert.NoError(t, err) + assert.Equal(t, "value", v.Value) + assert.Equal(t, yaml.ScalarNode, v.Kind) +} + +func TestMarshalMapValue(t *testing.T) { + var mapValue = dyn.NewValue( + map[string]dyn.Value{ + "key3": dyn.NewValue("value3", dyn.Location{File: "file", Line: 3, Column: 2}), + "key2": dyn.NewValue("value2", dyn.Location{File: "file", Line: 2, Column: 2}), + "key1": dyn.NewValue("value1", dyn.Location{File: "file", Line: 1, Column: 2}), + }, + 
dyn.Location{File: "file", Line: 1, Column: 2}, + ) + v, err := ToYamlNode(mapValue) + assert.NoError(t, err) + assert.Equal(t, yaml.MappingNode, v.Kind) + assert.Equal(t, "key1", v.Content[0].Value) + assert.Equal(t, "value1", v.Content[1].Value) + + assert.Equal(t, "key2", v.Content[2].Value) + assert.Equal(t, "value2", v.Content[3].Value) + + assert.Equal(t, "key3", v.Content[4].Value) + assert.Equal(t, "value3", v.Content[5].Value) +} + +func TestMarshalNestedValues(t *testing.T) { + var mapValue = dyn.NewValue( + map[string]dyn.Value{ + "key1": dyn.NewValue( + map[string]dyn.Value{ + "key2": dyn.NewValue("value", dyn.Location{File: "file", Line: 1, Column: 2}), + }, + dyn.Location{File: "file", Line: 1, Column: 2}, + ), + }, + dyn.Location{File: "file", Line: 1, Column: 2}, + ) + v, err := ToYamlNode(mapValue) + assert.NoError(t, err) + assert.Equal(t, yaml.MappingNode, v.Kind) + assert.Equal(t, "key1", v.Content[0].Value) + assert.Equal(t, yaml.MappingNode, v.Content[1].Kind) + assert.Equal(t, "key2", v.Content[1].Content[0].Value) + assert.Equal(t, "value", v.Content[1].Content[1].Value) +} + +func TestMarshalHexadecimalValueIsQuoted(t *testing.T) { + var hexValue = dyn.NewValue(0x123, dyn.Location{}) + v, err := ToYamlNode(hexValue) + assert.NoError(t, err) + assert.Equal(t, "291", v.Value) + assert.Equal(t, yaml.Style(0), v.Style) + assert.Equal(t, yaml.ScalarNode, v.Kind) + + var stringValue = dyn.NewValue("0x123", dyn.Location{}) + v, err = ToYamlNode(stringValue) + assert.NoError(t, err) + assert.Equal(t, "0x123", v.Value) + assert.Equal(t, yaml.DoubleQuotedStyle, v.Style) + assert.Equal(t, yaml.ScalarNode, v.Kind) +} + +func TestMarshalBinaryValueIsQuoted(t *testing.T) { + var binaryValue = dyn.NewValue(0b101, dyn.Location{}) + v, err := ToYamlNode(binaryValue) + assert.NoError(t, err) + assert.Equal(t, "5", v.Value) + assert.Equal(t, yaml.Style(0), v.Style) + assert.Equal(t, yaml.ScalarNode, v.Kind) + + var stringValue = dyn.NewValue("0b101", 
dyn.Location{}) + v, err = ToYamlNode(stringValue) + assert.NoError(t, err) + assert.Equal(t, "0b101", v.Value) + assert.Equal(t, yaml.DoubleQuotedStyle, v.Style) + assert.Equal(t, yaml.ScalarNode, v.Kind) +} + +func TestMarshalOctalValueIsQuoted(t *testing.T) { + var octalValue = dyn.NewValue(0123, dyn.Location{}) + v, err := ToYamlNode(octalValue) + assert.NoError(t, err) + assert.Equal(t, "83", v.Value) + assert.Equal(t, yaml.Style(0), v.Style) + assert.Equal(t, yaml.ScalarNode, v.Kind) + + var stringValue = dyn.NewValue("0123", dyn.Location{}) + v, err = ToYamlNode(stringValue) + assert.NoError(t, err) + assert.Equal(t, "0123", v.Value) + assert.Equal(t, yaml.DoubleQuotedStyle, v.Style) + assert.Equal(t, yaml.ScalarNode, v.Kind) +} + +func TestMarshalFloatValueIsQuoted(t *testing.T) { + var floatValue = dyn.NewValue(1.0, dyn.Location{}) + v, err := ToYamlNode(floatValue) + assert.NoError(t, err) + assert.Equal(t, "1", v.Value) + assert.Equal(t, yaml.Style(0), v.Style) + assert.Equal(t, yaml.ScalarNode, v.Kind) + + var stringValue = dyn.NewValue("1.0", dyn.Location{}) + v, err = ToYamlNode(stringValue) + assert.NoError(t, err) + assert.Equal(t, "1.0", v.Value) + assert.Equal(t, yaml.DoubleQuotedStyle, v.Style) + assert.Equal(t, yaml.ScalarNode, v.Kind) +} + +func TestMarshalBoolValueIsQuoted(t *testing.T) { + var boolValue = dyn.NewValue(true, dyn.Location{}) + v, err := ToYamlNode(boolValue) + assert.NoError(t, err) + assert.Equal(t, "true", v.Value) + assert.Equal(t, yaml.Style(0), v.Style) + assert.Equal(t, yaml.ScalarNode, v.Kind) + + var stringValue = dyn.NewValue("true", dyn.Location{}) + v, err = ToYamlNode(stringValue) + assert.NoError(t, err) + assert.Equal(t, "true", v.Value) + assert.Equal(t, yaml.DoubleQuotedStyle, v.Style) + assert.Equal(t, yaml.ScalarNode, v.Kind) +} diff --git a/libs/dyn/yamlsaver/utils.go b/libs/dyn/yamlsaver/utils.go new file mode 100644 index 00000000..0fb4064b --- /dev/null +++ b/libs/dyn/yamlsaver/utils.go @@ -0,0 +1,49 @@ 
+package yamlsaver + +import ( + "fmt" + "slices" + + "github.com/databricks/cli/libs/dyn" + "github.com/databricks/cli/libs/dyn/convert" +) + +// Converts a struct to map. Skips any nil fields. +// It uses `skipFields` to skip unnecessary fields. +// Uses `order` to define the order of keys in resulting output +func ConvertToMapValue(strct any, order *Order, skipFields []string, dst map[string]dyn.Value) (dyn.Value, error) { + ref := dyn.NilValue + mv, err := convert.FromTyped(strct, ref) + if err != nil { + return dyn.NilValue, err + } + + if mv.Kind() != dyn.KindMap { + return dyn.InvalidValue, fmt.Errorf("expected map, got %s", mv.Kind()) + } + + return skipAndOrder(mv, order, skipFields, dst) +} + +func skipAndOrder(mv dyn.Value, order *Order, skipFields []string, dst map[string]dyn.Value) (dyn.Value, error) { + for k, v := range mv.MustMap() { + if v.Kind() == dyn.KindNil { + continue + } + + if slices.Contains(skipFields, k) { + continue + } + + // If the value is already defined in destination, it means it was + // manually set due to custom ordering or other customisation required + // So we're skipping processing it again + if _, ok := dst[k]; ok { + continue + } + + dst[k] = dyn.NewValue(v.Value(), dyn.Location{Line: order.Get(k)}) + } + + return dyn.V(dst), nil +} diff --git a/libs/dyn/yamlsaver/utils_test.go b/libs/dyn/yamlsaver/utils_test.go new file mode 100644 index 00000000..32c9143b --- /dev/null +++ b/libs/dyn/yamlsaver/utils_test.go @@ -0,0 +1,48 @@ +package yamlsaver + +import ( + "testing" + + "github.com/databricks/cli/libs/dyn" + "github.com/stretchr/testify/assert" +) + +func TestConvertToMapValueWithOrder(t *testing.T) { + type test struct { + Name string `json:"name"` + Map map[string]string `json:"map"` + List []string `json:"list"` + LongNameField string `json:"long_name_field"` + ForceSendFields []string `json:"-"` + Format string `json:"format"` + } + + v := &test{ + Name: "test", + Map: map[string]string{ + "key1": "value1", + 
"key2": "value2", + }, + List: []string{"a", "b", "c"}, + ForceSendFields: []string{ + "Name", + }, + LongNameField: "long name goes here", + } + result, err := ConvertToMapValue(v, NewOrder([]string{"list", "name", "map"}), []string{"format"}, map[string]dyn.Value{}) + assert.NoError(t, err) + + assert.Equal(t, map[string]dyn.Value{ + "list": dyn.NewValue([]dyn.Value{ + dyn.V("a"), + dyn.V("b"), + dyn.V("c"), + }, dyn.Location{Line: -3}), + "name": dyn.NewValue("test", dyn.Location{Line: -2}), + "map": dyn.NewValue(map[string]dyn.Value{ + "key1": dyn.V("value1"), + "key2": dyn.V("value2"), + }, dyn.Location{Line: -1}), + "long_name_field": dyn.NewValue("long name goes here", dyn.Location{Line: 1}), + }, result.MustMap()) +} diff --git a/libs/notebook/ext.go b/libs/notebook/ext.go new file mode 100644 index 00000000..28d08c11 --- /dev/null +++ b/libs/notebook/ext.go @@ -0,0 +1,23 @@ +package notebook + +import "github.com/databricks/databricks-sdk-go/service/workspace" + +func GetExtensionByLanguage(objectInfo *workspace.ObjectInfo) string { + if objectInfo.ObjectType != workspace.ObjectTypeNotebook { + return "" + } + + switch objectInfo.Language { + case workspace.LanguagePython: + return ".py" + case workspace.LanguageR: + return ".r" + case workspace.LanguageScala: + return ".scala" + case workspace.LanguageSql: + return ".sql" + default: + // Do not add any extension to the file name + return "" + } +} diff --git a/libs/textutil/textutil.go b/libs/textutil/textutil.go new file mode 100644 index 00000000..a5d17d55 --- /dev/null +++ b/libs/textutil/textutil.go @@ -0,0 +1,20 @@ +package textutil + +import ( + "strings" + "unicode" +) + +// We leave the full range of unicode letters in tact, but remove all "special" characters, +// including spaces and dots, which are not supported in e.g. experiment names or YAML keys. 
+func NormalizeString(name string) string { + name = strings.ToLower(name) + return strings.Map(replaceNonAlphanumeric, name) +} + +func replaceNonAlphanumeric(r rune) rune { + if unicode.IsLetter(r) || unicode.IsDigit(r) { + return r + } + return '_' +} diff --git a/libs/textutil/textutil_test.go b/libs/textutil/textutil_test.go new file mode 100644 index 00000000..fb8bf0b6 --- /dev/null +++ b/libs/textutil/textutil_test.go @@ -0,0 +1,54 @@ +package textutil + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNormalizeString(t *testing.T) { + cases := []struct { + input string + expected string + }{ + { + input: "test", + expected: "test", + }, + { + input: "test test", + expected: "test_test", + }, + { + input: "test-test", + expected: "test_test", + }, + { + input: "test_test", + expected: "test_test", + }, + { + input: "test.test", + expected: "test_test", + }, + { + input: "test/test", + expected: "test_test", + }, + { + input: "test/test.test", + expected: "test_test_test", + }, + { + input: "TestTest", + expected: "testtest", + }, + { + input: "TestTestTest", + expected: "testtesttest", + }} + + for _, c := range cases { + assert.Equal(t, c.expected, NormalizeString(c.input)) + } +}