From 8ae4bde77327099d0de53f7cbf22d60b67d9cb6b Mon Sep 17 00:00:00 2001 From: kartikgupta-db Date: Wed, 13 Sep 2023 12:50:29 +0200 Subject: [PATCH] address feedback --- bundle/config/mutator/trampoline.go | 26 ++--------- bundle/python/notebook_task_wrappers.go | 31 +++++-------- bundle/python/wheel_task_wrappers.go | 5 +- libs/jobs/utils.go | 27 +++++++++++ libs/jobs/utils_test.go | 62 +++++++++++++++++++++++++ 5 files changed, 106 insertions(+), 45 deletions(-) create mode 100644 libs/jobs/utils.go create mode 100644 libs/jobs/utils_test.go diff --git a/bundle/config/mutator/trampoline.go b/bundle/config/mutator/trampoline.go index 48ddc796e..ea938fdcf 100644 --- a/bundle/config/mutator/trampoline.go +++ b/bundle/config/mutator/trampoline.go @@ -9,17 +9,13 @@ import ( "text/template" "github.com/databricks/cli/bundle" + jobs_utils "github.com/databricks/cli/libs/jobs" "github.com/databricks/databricks-sdk-go/service/jobs" ) -type TaskWithJobKey struct { - Task *jobs.Task - JobKey string -} - type TrampolineFunctions interface { GetTemplateData(b *bundle.Bundle, task *jobs.Task) (map[string]any, error) - GetTasks(b *bundle.Bundle) []TaskWithJobKey + GetTasks(b *bundle.Bundle) []jobs_utils.TaskWithJobKey GetTemplate(b *bundle.Bundle, task *jobs.Task) (string, error) CleanUp(task *jobs.Task) error } @@ -35,22 +31,6 @@ func NewTrampoline( return &trampoline{name, functions} } -func GetTasksWithJobKeyBy(b *bundle.Bundle, filter func(*jobs.Task) bool) []TaskWithJobKey { - tasks := make([]TaskWithJobKey, 0) - for k := range b.Config.Resources.Jobs { - for i := range b.Config.Resources.Jobs[k].Tasks { - t := &b.Config.Resources.Jobs[k].Tasks[i] - if filter(t) { - tasks = append(tasks, TaskWithJobKey{ - JobKey: k, - Task: t, - }) - } - } - } - return tasks -} - func (m *trampoline) Name() string { return fmt.Sprintf("trampoline(%s)", m.name) } @@ -66,7 +46,7 @@ func (m *trampoline) Apply(ctx context.Context, b *bundle.Bundle) error { return nil } -func (m *trampoline) generateNotebookWrapper(b *bundle.Bundle, task TaskWithJobKey) error { +func (m *trampoline) generateNotebookWrapper(b *bundle.Bundle, task jobs_utils.TaskWithJobKey) error { internalDir, err := b.InternalDir() if err != nil { return err diff --git a/bundle/python/notebook_task_wrappers.go b/bundle/python/notebook_task_wrappers.go index 7186497f3..304cfdf83 100644 --- a/bundle/python/notebook_task_wrappers.go +++ b/bundle/python/notebook_task_wrappers.go @@ -10,6 +10,7 @@ import ( "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config/mutator" + jobs_utils "github.com/databricks/cli/libs/jobs" "github.com/databricks/databricks-sdk-go/service/jobs" ) @@ -44,20 +45,21 @@ func localNotebookPath(b *bundle.Bundle, task *jobs.Task) (string, error) { if err == nil { return fmt.Sprintf("%s.py", localPath), nil } + return "", fmt.Errorf("notebook %s not found", localPath) } -func (n *notebookTrampoline) GetTasks(b *bundle.Bundle) []mutator.TaskWithJobKey { - return mutator.GetTasksWithJobKeyBy(b, func(task *jobs.Task) bool { +func (n *notebookTrampoline) GetTasks(b *bundle.Bundle) []jobs_utils.TaskWithJobKey { + return jobs_utils.GetTasksWithJobKeyBy(b, func(task *jobs.Task) bool { if task.NotebookTask == nil || task.NotebookTask.Source == jobs.SourceGit { return false } - localPath, err := localNotebookPath(b, task) - if err != nil { - return false - } - return strings.HasSuffix(localPath, ".ipynb") || strings.HasSuffix(localPath, ".py") + _, err := localNotebookPath(b, task) + // We assume if the notebook is not available locally in the bundle + // then the user has it somewhere in the workspace. For these + // out of bundle notebooks we do not want to write a trampoline. + return err == nil }) } @@ -66,13 +68,6 @@ func (n *notebookTrampoline) CleanUp(task *jobs.Task) error { } func (n *notebookTrampoline) GetTemplate(b *bundle.Bundle, task *jobs.Task) (string, error) { - if task.NotebookTask == nil { - return "", fmt.Errorf("nil notebook path") - } - - if task.NotebookTask.Source == jobs.SourceGit { - return "", fmt.Errorf("source must be workspace %s", task.NotebookTask.Source) - } localPath, err := localNotebookPath(b, task) if err != nil { return "", err @@ -89,14 +84,10 @@ func (n *notebookTrampoline) GetTemplate(b *bundle.Bundle, task *jobs.Task) (str lines := strings.Split(s, "\n") if strings.HasPrefix(lines[0], "# Databricks notebook source") { - return getDbnbTemplate(strings.Join(lines, "\n")) + return getDbnbTemplate(s) } - return getPyTemplate(s), nil -} - -func getPyTemplate(s string) string { - return pyTrampolineData + return pyTrampolineData, nil } func getDbnbTemplate(s string) (string, error) { diff --git a/bundle/python/wheel_task_wrappers.go b/bundle/python/wheel_task_wrappers.go index c70a01122..df142e4cf 100644 --- a/bundle/python/wheel_task_wrappers.go +++ b/bundle/python/wheel_task_wrappers.go @@ -7,6 +7,7 @@ import ( "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config/mutator" + jobs_utils "github.com/databricks/cli/libs/jobs" "github.com/databricks/databricks-sdk-go/service/jobs" ) @@ -67,8 +68,8 @@ func (t *pythonTrampoline) GetTemplate(b *bundle.Bundle, task *jobs.Task) (strin return NOTEBOOK_TEMPLATE, nil } -func (t *pythonTrampoline) GetTasks(b *bundle.Bundle) []mutator.TaskWithJobKey { - return mutator.GetTasksWithJobKeyBy(b, func(task *jobs.Task) bool { +func (t *pythonTrampoline) GetTasks(b *bundle.Bundle) []jobs_utils.TaskWithJobKey { + return jobs_utils.GetTasksWithJobKeyBy(b, func(task *jobs.Task) bool { return task.PythonWheelTask != nil }) } diff --git a/libs/jobs/utils.go b/libs/jobs/utils.go new file mode 100644 index 000000000..881f4ed8c --- /dev/null +++ b/libs/jobs/utils.go @@ -0,0 +1,27 @@ +package jobs_utils + +import ( + "github.com/databricks/cli/bundle" + "github.com/databricks/databricks-sdk-go/service/jobs" +) + +type TaskWithJobKey struct { + Task *jobs.Task + JobKey string +} + +func GetTasksWithJobKeyBy(b *bundle.Bundle, filter func(*jobs.Task) bool) []TaskWithJobKey { + tasks := make([]TaskWithJobKey, 0) + for k := range b.Config.Resources.Jobs { + for i := range b.Config.Resources.Jobs[k].Tasks { + t := &b.Config.Resources.Jobs[k].Tasks[i] + if filter(t) { + tasks = append(tasks, TaskWithJobKey{ + JobKey: k, + Task: t, + }) + } + } + } + return tasks +} diff --git a/libs/jobs/utils_test.go b/libs/jobs/utils_test.go new file mode 100644 index 000000000..ecbc67dc2 --- /dev/null +++ b/libs/jobs/utils_test.go @@ -0,0 +1,62 @@ +package jobs_utils + +import ( + "testing" + + "github.com/databricks/cli/bundle" + "github.com/databricks/cli/bundle/config" + "github.com/databricks/cli/bundle/config/resources" + "github.com/databricks/databricks-sdk-go/service/jobs" + "github.com/stretchr/testify/require" +) + +func TestCorrectlyFilterTasksByFn(t *testing.T) { + bundle := &bundle.Bundle{ + Config: config.Root{ + Resources: config.Resources{ + Jobs: map[string]*resources.Job{ + "job1": { + JobSettings: &jobs.JobSettings{ + Tasks: []jobs.Task{ + { + TaskKey: "job1_key1", + PythonWheelTask: &jobs.PythonWheelTask{}, + }, + { + TaskKey: "job1_key2", + NotebookTask: &jobs.NotebookTask{}, + }, + }, + }, + }, + "job2": { + JobSettings: &jobs.JobSettings{ + Tasks: []jobs.Task{ + { + TaskKey: "job1_key1", + PythonWheelTask: &jobs.PythonWheelTask{}, + }, + { + TaskKey: "job2_key2", + NotebookTask: &jobs.NotebookTask{}, + }, + }, + }, + }, + }, + }, + }, + } + + tasks := GetTasksWithJobKeyBy(bundle, func(task *jobs.Task) bool { + return task.PythonWheelTask != nil + }) + + require.Len(t, tasks, 2) + + require.Equal(t, "job1", tasks[0].JobKey) + require.Equal(t, "job1_key1", tasks[0].Task.TaskKey) + + require.Equal(t, "job2", tasks[1].JobKey) + require.Equal(t, "job1_key1", tasks[1].Task.TaskKey) +}