diff --git a/bundle/config/mutator/trampoline.go b/bundle/config/mutator/trampoline.go index 11c22708e..7fe6b8430 100644 --- a/bundle/config/mutator/trampoline.go +++ b/bundle/config/mutator/trampoline.go @@ -3,6 +3,7 @@ package mutator import ( "context" "fmt" + "log" "os" "path" "path/filepath" @@ -20,26 +21,19 @@ type TaskWithJobKey struct { type TrampolineFunctions interface { GetTemplateData(b *bundle.Bundle, task *jobs.Task) (map[string]any, error) GetTasks(b *bundle.Bundle) []TaskWithJobKey + GetTemplate(b *bundle.Bundle, task *jobs.Task) (string, error) CleanUp(task *jobs.Task) error } type trampoline struct { name string functions TrampolineFunctions - template func(*jobs.Task) (string, error) } func NewTrampoline( name string, functions TrampolineFunctions, - template func(*jobs.Task) (string, error), ) *trampoline { - return &trampoline{name, functions, template} -} - -// Shorthand for generating template function for templates -// that are same irrespective of the task. -func StaticTrampolineTemplate(template string) func(*jobs.Task) (string, error) { - return func(*jobs.Task) (string, error) { return template, nil } + return &trampoline{name, functions} } func GetTasksWithJobKeyBy(b *bundle.Bundle, filter func(*jobs.Task) bool) []TaskWithJobKey { @@ -64,6 +58,7 @@ func (m *trampoline) Name() string { func (m *trampoline) Apply(ctx context.Context, b *bundle.Bundle) error { tasks := m.functions.GetTasks(b) for _, task := range tasks { + log.Default().Printf("%s, %s task", task.Task.TaskKey, task.Task.NotebookTask.NotebookPath) err := m.generateNotebookWrapper(b, task) if err != nil { return err @@ -97,7 +92,7 @@ func (m *trampoline) generateNotebookWrapper(b *bundle.Bundle, task TaskWithJobK return err } - templateString, err := m.template(task.Task) + templateString, err := m.functions.GetTemplate(b, task.Task) if err != nil { return err } diff --git a/bundle/config/mutator/trampoline_test.go b/bundle/config/mutator/trampoline_test.go index 4bfce9f96..f8a376a7f 100644 --- a/bundle/config/mutator/trampoline_test.go +++ b/bundle/config/mutator/trampoline_test.go @@ -14,7 +14,8 @@ import ( "github.com/stretchr/testify/require" ) -type functions struct{} +type functions struct { +} func (f *functions) GetTasks(b *bundle.Bundle) []TaskWithJobKey { tasks := make([]TaskWithJobKey, 0) @@ -43,6 +44,10 @@ func (f *functions) CleanUp(task *jobs.Task) error { return nil } +func (f *functions) GetTemplate(b *bundle.Bundle, task *jobs.Task) (string, error) { + return "Hello from {{.MyName}}", nil +} + func TestGenerateTrampoline(t *testing.T) { tmpDir := t.TempDir() @@ -78,7 +83,7 @@ func TestGenerateTrampoline(t *testing.T) { ctx := context.Background() funcs := functions{} - trampoline := NewTrampoline("test_trampoline", &funcs, StaticTrampolineTemplate("Hello from {{.MyName}}")) + trampoline := NewTrampoline("test_trampoline", &funcs) err := bundle.Apply(ctx, b, trampoline) require.NoError(t, err) diff --git a/bundle/python/trampoline_data/notebook.py b/bundle/python/trampoline_data/notebook.py index 10437cac1..c07d842ec 100644 --- a/bundle/python/trampoline_data/notebook.py +++ b/bundle/python/trampoline_data/notebook.py @@ -8,11 +8,11 @@ def databricks_preamble(): src_file_dir = None project_root_dir = None - src_file = {{.SourceFile}} + src_file = "{{.SourceFile}}" src_file_dir = os.path.dirname(src_file) os.chdir(src_file_dir) - project_root_dir = {{.ProjectRoot}} + project_root_dir = "{{.ProjectRoot}}" sys.path.insert(0, project_root_dir) def parse_databricks_magic_lines(lines: List[str]): diff --git a/bundle/python/transform.go b/bundle/python/transform.go index 627678864..27c790665 100644 --- a/bundle/python/transform.go +++ b/bundle/python/transform.go @@ -49,7 +49,6 @@ func TransformWheelTask() bundle.Mutator { return mutator.NewTrampoline( "python_wheel", &pythonTrampoline{}, - mutator.StaticTrampolineTemplate(NOTEBOOK_TEMPLATE), ) } @@ -62,6 +61,10 @@ func (t *pythonTrampoline) CleanUp(task *jobs.Task) error { return nil } +func (t *pythonTrampoline) GetTemplate(b *bundle.Bundle, task *jobs.Task) (string, error) { + return NOTEBOOK_TEMPLATE, nil +} + func (t *pythonTrampoline) GetTasks(b *bundle.Bundle) []mutator.TaskWithJobKey { return mutator.GetTasksWithJobKeyBy(b, func(task *jobs.Task) bool { return task.PythonWheelTask != nil diff --git a/bundle/python/workflow_wrappers.go b/bundle/python/workflow_wrappers.go index 00f0818c8..b0f223957 100644 --- a/bundle/python/workflow_wrappers.go +++ b/bundle/python/workflow_wrappers.go @@ -1,9 +1,11 @@ package python import ( + _ "embed" "encoding/json" "fmt" "os" + "path/filepath" "strings" "github.com/databricks/cli/bundle" @@ -11,28 +13,51 @@ import ( "github.com/databricks/databricks-sdk-go/service/jobs" ) -// go:embed trampoline_data/notebook.py +//go:embed trampoline_data/notebook.py var notebookTrampolineData string -// go:embed trampoline_data/python.py +//go:embed trampoline_data/python.py var pyTrampolineData string func TransforNotebookTask() bundle.Mutator { return mutator.NewTrampoline( "python_notebook", ¬ebookTrampoline{}, - getTemplate, ) } type notebookTrampoline struct{} +func localNotebookPath(b *bundle.Bundle, task *jobs.Task) (string, error) { + remotePath := task.NotebookTask.NotebookPath + relRemotePath, err := filepath.Rel(b.Config.Workspace.FilesPath, remotePath) + if err != nil { + return "", err + } + localPath := filepath.Join(b.Config.Path, filepath.FromSlash(relRemotePath)) + _, err = os.Stat(fmt.Sprintf("%s.ipynb", localPath)) + if err == nil { + return fmt.Sprintf("%s.ipynb", localPath), nil + } + + _, err = os.Stat(fmt.Sprintf("%s.py", localPath)) + if err == nil { + return fmt.Sprintf("%s.py", localPath), nil + } + return "", fmt.Errorf("notebook %s not found", localPath) +} + func (n *notebookTrampoline) GetTasks(b *bundle.Bundle) []mutator.TaskWithJobKey { return mutator.GetTasksWithJobKeyBy(b, func(task *jobs.Task) bool { - return task.NotebookTask != nil && - task.NotebookTask.Source == jobs.SourceWorkspace && - (strings.HasSuffix(task.NotebookTask.NotebookPath, ".ipynb") || - strings.HasSuffix(task.NotebookTask.NotebookPath, ".py")) + if task.NotebookTask == nil || + task.NotebookTask.Source == jobs.SourceGit { + return false + } + localPath, err := localNotebookPath(b, task) + if err != nil { + return false + } + return strings.HasSuffix(localPath, ".ipynb") || strings.HasSuffix(localPath, ".py") }) } @@ -40,21 +65,25 @@ func (n *notebookTrampoline) CleanUp(task *jobs.Task) error { return nil } -func getTemplate(task *jobs.Task) (string, error) { +func (n *notebookTrampoline) GetTemplate(b *bundle.Bundle, task *jobs.Task) (string, error) { if task.NotebookTask == nil { return "", fmt.Errorf("nil notebook path") } - if task.NotebookTask.Source != jobs.SourceWorkspace { - return "", fmt.Errorf("source must be workspace") + if task.NotebookTask.Source == jobs.SourceGit { + return "", fmt.Errorf("source must be workspace %s", task.NotebookTask.Source) + } + localPath, err := localNotebookPath(b, task) + if err != nil { + return "", err } - bytesData, err := os.ReadFile(task.NotebookTask.NotebookPath) + bytesData, err := os.ReadFile(localPath) if err != nil { return "", err } s := strings.TrimSpace(string(bytesData)) - if strings.HasSuffix(task.NotebookTask.NotebookPath, ".ipynb") { + if strings.HasSuffix(localPath, ".ipynb") { return getIpynbTemplate(s) }