This commit is contained in:
kartikgupta-db 2023-09-05 13:39:47 +02:00
parent 1d02023a0d
commit 0f0900b081
No known key found for this signature in database
GPG Key ID: 6AD5FA11FACDEA39
5 changed files with 59 additions and 27 deletions

View File

@ -3,6 +3,7 @@ package mutator
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
@ -20,26 +21,19 @@ type TaskWithJobKey struct {
type TrampolineFunctions interface { type TrampolineFunctions interface {
GetTemplateData(b *bundle.Bundle, task *jobs.Task) (map[string]any, error) GetTemplateData(b *bundle.Bundle, task *jobs.Task) (map[string]any, error)
GetTasks(b *bundle.Bundle) []TaskWithJobKey GetTasks(b *bundle.Bundle) []TaskWithJobKey
GetTemplate(b *bundle.Bundle, task *jobs.Task) (string, error)
CleanUp(task *jobs.Task) error CleanUp(task *jobs.Task) error
} }
type trampoline struct { type trampoline struct {
name string name string
functions TrampolineFunctions functions TrampolineFunctions
template func(*jobs.Task) (string, error)
} }
func NewTrampoline( func NewTrampoline(
name string, name string,
functions TrampolineFunctions, functions TrampolineFunctions,
template func(*jobs.Task) (string, error),
) *trampoline { ) *trampoline {
return &trampoline{name, functions, template} return &trampoline{name, functions}
}
// Shorthand for generating template function for templates
// that are same irrespective of the task.
func StaticTrampolineTemplate(template string) func(*jobs.Task) (string, error) {
return func(*jobs.Task) (string, error) { return template, nil }
} }
func GetTasksWithJobKeyBy(b *bundle.Bundle, filter func(*jobs.Task) bool) []TaskWithJobKey { func GetTasksWithJobKeyBy(b *bundle.Bundle, filter func(*jobs.Task) bool) []TaskWithJobKey {
@ -64,6 +58,7 @@ func (m *trampoline) Name() string {
func (m *trampoline) Apply(ctx context.Context, b *bundle.Bundle) error { func (m *trampoline) Apply(ctx context.Context, b *bundle.Bundle) error {
tasks := m.functions.GetTasks(b) tasks := m.functions.GetTasks(b)
for _, task := range tasks { for _, task := range tasks {
log.Default().Printf("%s, %s task", task.Task.TaskKey, task.Task.NotebookTask.NotebookPath)
err := m.generateNotebookWrapper(b, task) err := m.generateNotebookWrapper(b, task)
if err != nil { if err != nil {
return err return err
@ -97,7 +92,7 @@ func (m *trampoline) generateNotebookWrapper(b *bundle.Bundle, task TaskWithJobK
return err return err
} }
templateString, err := m.template(task.Task) templateString, err := m.functions.GetTemplate(b, task.Task)
if err != nil { if err != nil {
return err return err
} }

View File

@ -14,7 +14,8 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
type functions struct{} type functions struct {
}
func (f *functions) GetTasks(b *bundle.Bundle) []TaskWithJobKey { func (f *functions) GetTasks(b *bundle.Bundle) []TaskWithJobKey {
tasks := make([]TaskWithJobKey, 0) tasks := make([]TaskWithJobKey, 0)
@ -43,6 +44,10 @@ func (f *functions) CleanUp(task *jobs.Task) error {
return nil return nil
} }
func (f *functions) GetTemplate(b *bundle.Bundle, task *jobs.Task) (string, error) {
return "Hello from {{.MyName}}", nil
}
func TestGenerateTrampoline(t *testing.T) { func TestGenerateTrampoline(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
@ -78,7 +83,7 @@ func TestGenerateTrampoline(t *testing.T) {
ctx := context.Background() ctx := context.Background()
funcs := functions{} funcs := functions{}
trampoline := NewTrampoline("test_trampoline", &funcs, StaticTrampolineTemplate("Hello from {{.MyName}}")) trampoline := NewTrampoline("test_trampoline", &funcs)
err := bundle.Apply(ctx, b, trampoline) err := bundle.Apply(ctx, b, trampoline)
require.NoError(t, err) require.NoError(t, err)

View File

@ -8,11 +8,11 @@ def databricks_preamble():
src_file_dir = None src_file_dir = None
project_root_dir = None project_root_dir = None
src_file = {{.SourceFile}} src_file = "{{.SourceFile}}"
src_file_dir = os.path.dirname(src_file) src_file_dir = os.path.dirname(src_file)
os.chdir(src_file_dir) os.chdir(src_file_dir)
project_root_dir = {{.ProjectRoot}} project_root_dir = "{{.ProjectRoot}}"
sys.path.insert(0, project_root_dir) sys.path.insert(0, project_root_dir)
def parse_databricks_magic_lines(lines: List[str]): def parse_databricks_magic_lines(lines: List[str]):

View File

@ -49,7 +49,6 @@ func TransformWheelTask() bundle.Mutator {
return mutator.NewTrampoline( return mutator.NewTrampoline(
"python_wheel", "python_wheel",
&pythonTrampoline{}, &pythonTrampoline{},
mutator.StaticTrampolineTemplate(NOTEBOOK_TEMPLATE),
) )
} }
@ -62,6 +61,10 @@ func (t *pythonTrampoline) CleanUp(task *jobs.Task) error {
return nil return nil
} }
func (t *pythonTrampoline) GetTemplate(b *bundle.Bundle, task *jobs.Task) (string, error) {
return NOTEBOOK_TEMPLATE, nil
}
func (t *pythonTrampoline) GetTasks(b *bundle.Bundle) []mutator.TaskWithJobKey { func (t *pythonTrampoline) GetTasks(b *bundle.Bundle) []mutator.TaskWithJobKey {
return mutator.GetTasksWithJobKeyBy(b, func(task *jobs.Task) bool { return mutator.GetTasksWithJobKeyBy(b, func(task *jobs.Task) bool {
return task.PythonWheelTask != nil return task.PythonWheelTask != nil

View File

@ -1,9 +1,11 @@
package python package python
import ( import (
_ "embed"
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
"path/filepath"
"strings" "strings"
"github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle"
@ -11,28 +13,51 @@ import (
"github.com/databricks/databricks-sdk-go/service/jobs" "github.com/databricks/databricks-sdk-go/service/jobs"
) )
// go:embed trampoline_data/notebook.py //go:embed trampoline_data/notebook.py
var notebookTrampolineData string var notebookTrampolineData string
// go:embed trampoline_data/python.py //go:embed trampoline_data/python.py
var pyTrampolineData string var pyTrampolineData string
func TransforNotebookTask() bundle.Mutator { func TransforNotebookTask() bundle.Mutator {
return mutator.NewTrampoline( return mutator.NewTrampoline(
"python_notebook", "python_notebook",
&notebookTrampoline{}, &notebookTrampoline{},
getTemplate,
) )
} }
type notebookTrampoline struct{} type notebookTrampoline struct{}
func localNotebookPath(b *bundle.Bundle, task *jobs.Task) (string, error) {
remotePath := task.NotebookTask.NotebookPath
relRemotePath, err := filepath.Rel(b.Config.Workspace.FilesPath, remotePath)
if err != nil {
return "", err
}
localPath := filepath.Join(b.Config.Path, filepath.FromSlash(relRemotePath))
_, err = os.Stat(fmt.Sprintf("%s.ipynb", localPath))
if err == nil {
return fmt.Sprintf("%s.ipynb", localPath), nil
}
_, err = os.Stat(fmt.Sprintf("%s.py", localPath))
if err == nil {
return fmt.Sprintf("%s.py", localPath), nil
}
return "", fmt.Errorf("notebook %s not found", localPath)
}
func (n *notebookTrampoline) GetTasks(b *bundle.Bundle) []mutator.TaskWithJobKey { func (n *notebookTrampoline) GetTasks(b *bundle.Bundle) []mutator.TaskWithJobKey {
return mutator.GetTasksWithJobKeyBy(b, func(task *jobs.Task) bool { return mutator.GetTasksWithJobKeyBy(b, func(task *jobs.Task) bool {
return task.NotebookTask != nil && if task.NotebookTask == nil ||
task.NotebookTask.Source == jobs.SourceWorkspace && task.NotebookTask.Source == jobs.SourceGit {
(strings.HasSuffix(task.NotebookTask.NotebookPath, ".ipynb") || return false
strings.HasSuffix(task.NotebookTask.NotebookPath, ".py")) }
localPath, err := localNotebookPath(b, task)
if err != nil {
return false
}
return strings.HasSuffix(localPath, ".ipynb") || strings.HasSuffix(localPath, ".py")
}) })
} }
@ -40,21 +65,25 @@ func (n *notebookTrampoline) CleanUp(task *jobs.Task) error {
return nil return nil
} }
func getTemplate(task *jobs.Task) (string, error) { func (n *notebookTrampoline) GetTemplate(b *bundle.Bundle, task *jobs.Task) (string, error) {
if task.NotebookTask == nil { if task.NotebookTask == nil {
return "", fmt.Errorf("nil notebook path") return "", fmt.Errorf("nil notebook path")
} }
if task.NotebookTask.Source != jobs.SourceWorkspace { if task.NotebookTask.Source == jobs.SourceGit {
return "", fmt.Errorf("source must be workspace") return "", fmt.Errorf("source must be workspace %s", task.NotebookTask.Source)
}
localPath, err := localNotebookPath(b, task)
if err != nil {
return "", err
} }
bytesData, err := os.ReadFile(task.NotebookTask.NotebookPath) bytesData, err := os.ReadFile(localPath)
if err != nil { if err != nil {
return "", err return "", err
} }
s := strings.TrimSpace(string(bytesData)) s := strings.TrimSpace(string(bytesData))
if strings.HasSuffix(task.NotebookTask.NotebookPath, ".ipynb") { if strings.HasSuffix(localPath, ".ipynb") {
return getIpynbTemplate(s) return getIpynbTemplate(s)
} }