This commit is contained in:
kartikgupta-db 2023-08-31 14:04:04 +02:00
parent 8d3381e0e0
commit 1d02023a0d
No known key found for this signature in database
GPG Key ID: 6AD5FA11FACDEA39
7 changed files with 235 additions and 29 deletions

View File

@ -18,24 +18,45 @@ type TaskWithJobKey struct {
}
type TrampolineFunctions interface {
GetTemplateData(task *jobs.Task) (map[string]any, error)
GetTemplateData(b *bundle.Bundle, task *jobs.Task) (map[string]any, error)
GetTasks(b *bundle.Bundle) []TaskWithJobKey
CleanUp(task *jobs.Task) error
}
type trampoline struct {
name string
functions TrampolineFunctions
template string
template func(*jobs.Task) (string, error)
}
func NewTrampoline(
name string,
functions TrampolineFunctions,
template string,
template func(*jobs.Task) (string, error),
) *trampoline {
return &trampoline{name, functions, template}
}
// Shorthand for generating template function for templates
// that are same irrespective of the task.
func StaticTrampolineTemplate(template string) func(*jobs.Task) (string, error) {
return func(*jobs.Task) (string, error) { return template, nil }
}
func GetTasksWithJobKeyBy(b *bundle.Bundle, filter func(*jobs.Task) bool) []TaskWithJobKey {
tasks := make([]TaskWithJobKey, 0)
for k := range b.Config.Resources.Jobs {
for _, t := range b.Config.Resources.Jobs[k].Tasks {
if filter(&t) {
tasks = append(tasks, TaskWithJobKey{
JobKey: k,
Task: &t,
})
}
}
}
return tasks
}
func (m *trampoline) Name() string {
return fmt.Sprintf("trampoline(%s)", m.name)
}
@ -57,7 +78,7 @@ func (m *trampoline) generateNotebookWrapper(b *bundle.Bundle, task TaskWithJobK
return err
}
notebookName := fmt.Sprintf("notebook_%s_%s", task.JobKey, task.Task.TaskKey)
notebookName := fmt.Sprintf("notebook_%s_%s_%s", m.name, task.JobKey, task.Task.TaskKey)
localNotebookPath := filepath.Join(internalDir, notebookName+".py")
err = os.MkdirAll(filepath.Dir(localNotebookPath), 0755)
@ -71,12 +92,16 @@ func (m *trampoline) generateNotebookWrapper(b *bundle.Bundle, task TaskWithJobK
}
defer f.Close()
data, err := m.functions.GetTemplateData(task.Task)
data, err := m.functions.GetTemplateData(b, task.Task)
if err != nil {
return err
}
t, err := template.New(notebookName).Parse(m.template)
templateString, err := m.template(task.Task)
if err != nil {
return err
}
t, err := template.New(notebookName).Parse(templateString)
if err != nil {
return err
}

View File

@ -28,7 +28,7 @@ func (f *functions) GetTasks(b *bundle.Bundle) []TaskWithJobKey {
return tasks
}
func (f *functions) GetTemplateData(task *jobs.Task) (map[string]any, error) {
func (f *functions) GetTemplateData(_ *bundle.Bundle, task *jobs.Task) (map[string]any, error) {
if task.PythonWheelTask == nil {
return nil, fmt.Errorf("PythonWheelTask cannot be nil")
}
@ -78,7 +78,7 @@ func TestGenerateTrampoline(t *testing.T) {
ctx := context.Background()
funcs := functions{}
trampoline := NewTrampoline("test_trampoline", &funcs, "Hello from {{.MyName}}")
trampoline := NewTrampoline("test_trampoline", &funcs, StaticTrampolineTemplate("Hello from {{.MyName}}"))
err := bundle.Apply(ctx, b, trampoline)
require.NoError(t, err)

View File

@ -22,6 +22,7 @@ func Deploy() bundle.Mutator {
artifacts.CleanUp(),
artifacts.UploadAll(),
python.TransformWheelTask(),
python.TransforNotebookTask(),
files.Upload(),
terraform.Interpolate(),
terraform.Write(),

View File

@ -0,0 +1,50 @@
# This cell is autogenerated by the Databricks Extension for VS Code
def databricks_preamble():
from IPython import get_ipython
from typing import List
from shlex import quote
import os
src_file_dir = None
project_root_dir = None
src_file = {{.SourceFile}}
src_file_dir = os.path.dirname(src_file)
os.chdir(src_file_dir)
project_root_dir = {{.ProjectRoot}}
sys.path.insert(0, project_root_dir)
def parse_databricks_magic_lines(lines: List[str]):
if len(lines) == 0 or src_file_dir is None:
return lines
first = ""
for line in lines:
if len(line.strip()) != 0:
first = line
break
if first.startswith("%"):
magic = first.split(" ")[0].strip().strip("%")
rest = ' '.join(first.split(" ")[1:])
if magic == "sh":
return [
"%sh\n",
f"cd {quote(src_file_dir)}\n",
rest.strip() + "\n",
*lines[1:]
]
return lines
ip = get_ipython()
ip.input_transformers_cleanup.append(parse_databricks_magic_lines)
try:
databricks_preamble()
del databricks_preamble
except Exception as e:
print("Error in databricks_preamble: " + str(e))

View File

@ -0,0 +1,36 @@
#This file is autogenerated by the Databricks Extension for VS Code
import runpy
import sys
import os
python_file = {{.PythonFile}}
project_root = {{.ProjectRoot}}
#remove databricks args from argv
sys.argv = sys.argv[1:]
# change working directory
os.chdir(os.path.dirname(python_file))
# update python path
sys.path.insert(0, project_root)
# provide spark globals
user_ns = {
"display": display,
"displayHTML": displayHTML,
"dbutils": dbutils,
"table": table,
"sql": sql,
"udf": udf,
"getArgument": getArgument,
"sc": sc,
"spark": spark,
"sqlContext": sqlContext,
}
# Set log level to "ERROR". See https://kb.databricks.com/notebooks/cmd-c-on-object-id-p0.html
import logging; logger = spark._jvm.org.apache.log4j;
logging.getLogger("py4j.java_gateway").setLevel(logging.ERROR)
runpy.run_path(python_file, run_name="__main__", init_globals=user_ns)
None

View File

@ -49,7 +49,7 @@ func TransformWheelTask() bundle.Mutator {
return mutator.NewTrampoline(
"python_wheel",
&pythonTrampoline{},
NOTEBOOK_TEMPLATE,
mutator.StaticTrampolineTemplate(NOTEBOOK_TEMPLATE),
)
}
@ -63,28 +63,12 @@ func (t *pythonTrampoline) CleanUp(task *jobs.Task) error {
}
func (t *pythonTrampoline) GetTasks(b *bundle.Bundle) []mutator.TaskWithJobKey {
r := b.Config.Resources
result := make([]mutator.TaskWithJobKey, 0)
for k := range b.Config.Resources.Jobs {
tasks := r.Jobs[k].JobSettings.Tasks
for i := range tasks {
task := &tasks[i]
// Keep only Python wheel tasks
if task.PythonWheelTask == nil {
continue
}
result = append(result, mutator.TaskWithJobKey{
JobKey: k,
Task: task,
})
}
}
return result
return mutator.GetTasksWithJobKeyBy(b, func(task *jobs.Task) bool {
return task.PythonWheelTask != nil
})
}
func (t *pythonTrampoline) GetTemplateData(task *jobs.Task) (map[string]any, error) {
func (t *pythonTrampoline) GetTemplateData(_ *bundle.Bundle, task *jobs.Task) (map[string]any, error) {
params, err := t.generateParameters(task.PythonWheelTask)
if err != nil {
return nil, err

View File

@ -0,0 +1,110 @@
package python
import (
"encoding/json"
"fmt"
"os"
"strings"
"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config/mutator"
"github.com/databricks/databricks-sdk-go/service/jobs"
)
// go:embed trampoline_data/notebook.py
var notebookTrampolineData string
// go:embed trampoline_data/python.py
var pyTrampolineData string
func TransforNotebookTask() bundle.Mutator {
return mutator.NewTrampoline(
"python_notebook",
&notebookTrampoline{},
getTemplate,
)
}
type notebookTrampoline struct{}
func (n *notebookTrampoline) GetTasks(b *bundle.Bundle) []mutator.TaskWithJobKey {
return mutator.GetTasksWithJobKeyBy(b, func(task *jobs.Task) bool {
return task.NotebookTask != nil &&
task.NotebookTask.Source == jobs.SourceWorkspace &&
(strings.HasSuffix(task.NotebookTask.NotebookPath, ".ipynb") ||
strings.HasSuffix(task.NotebookTask.NotebookPath, ".py"))
})
}
func (n *notebookTrampoline) CleanUp(task *jobs.Task) error {
return nil
}
func getTemplate(task *jobs.Task) (string, error) {
if task.NotebookTask == nil {
return "", fmt.Errorf("nil notebook path")
}
if task.NotebookTask.Source != jobs.SourceWorkspace {
return "", fmt.Errorf("source must be workspace")
}
bytesData, err := os.ReadFile(task.NotebookTask.NotebookPath)
if err != nil {
return "", err
}
s := strings.TrimSpace(string(bytesData))
if strings.HasSuffix(task.NotebookTask.NotebookPath, ".ipynb") {
return getIpynbTemplate(s)
}
lines := strings.Split(s, "\n")
if strings.HasPrefix(lines[0], "# Databricks notebook source") {
return getDbnbTemplate(strings.Join(lines[1:], "\n"))
}
//TODO return getPyTemplate(s), nil
return s, nil
}
func getDbnbTemplate(s string) (string, error) {
s = strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(s), "# Databricks notebook source"))
return fmt.Sprintf(`# Databricks notebook source
%s
# Command ----------
%s
`, notebookTrampolineData, s), nil
}
func getIpynbTemplate(s string) (string, error) {
var data map[string]any
err := json.Unmarshal([]byte(s), &data)
if err != nil {
return "", err
}
if data["cells"] == nil {
data["cells"] = []any{}
}
data["cells"] = append([]any{
map[string]any{
"cell_type": "code",
"source": []string{notebookTrampolineData},
},
}, data["cells"].([]any)...)
bytes, err := json.Marshal(data)
if err != nil {
return "", err
}
return string(bytes), nil
}
func (n *notebookTrampoline) GetTemplateData(b *bundle.Bundle, task *jobs.Task) (map[string]any, error) {
return map[string]any{
"ProjectRoot": b.Config.Workspace.FilesPath,
"SourceFile": task.NotebookTask.NotebookPath,
}, nil
}