more reusable code

This commit is contained in:
Andrew Nester 2023-08-18 11:23:45 +02:00
parent b10d50f521
commit aeb77d94c3
No known key found for this signature in database
GPG Key ID: 12BC628A44B7DA57
4 changed files with 211 additions and 95 deletions

View File

@ -0,0 +1,95 @@
package mutator
import (
"context"
"fmt"
"os"
"path"
"path/filepath"
"text/template"
"github.com/databricks/cli/bundle"
"github.com/databricks/databricks-sdk-go/service/jobs"
)
type fnTemplateData func(task *jobs.Task) (map[string]any, error)
type fnCleanUp func(task *jobs.Task)
type fnTasks func(b *bundle.Bundle) []*jobs.Task
type trampoline struct {
name string
getTasks fnTasks
templateData fnTemplateData
cleanUp fnCleanUp
template string
}
func NewTrampoline(
name string,
tasks fnTasks,
templateData fnTemplateData,
cleanUp fnCleanUp,
template string,
) *trampoline {
return &trampoline{name, tasks, templateData, cleanUp, template}
}
func (m *trampoline) Name() string {
return fmt.Sprintf("trampoline(%s)", m.name)
}
func (m *trampoline) Apply(ctx context.Context, b *bundle.Bundle) error {
tasks := m.getTasks(b)
for _, task := range tasks {
err := m.generateNotebookWrapper(b, task)
if err != nil {
return err
}
}
return nil
}
func (m *trampoline) generateNotebookWrapper(b *bundle.Bundle, task *jobs.Task) error {
internalDir, err := b.InternalDir()
if err != nil {
return err
}
notebookName := fmt.Sprintf("notebook_%s", task.TaskKey)
localNotebookPath := filepath.Join(internalDir, notebookName+".py")
err = os.MkdirAll(filepath.Dir(localNotebookPath), 0755)
if err != nil {
return err
}
f, err := os.Create(localNotebookPath)
if err != nil {
return err
}
defer f.Close()
data, err := m.templateData(task)
if err != nil {
return err
}
t, err := template.New(notebookName).Parse(m.template)
if err != nil {
return err
}
internalDirRel, err := filepath.Rel(b.Config.Path, internalDir)
if err != nil {
return err
}
m.cleanUp(task)
remotePath := path.Join(b.Config.Workspace.FilesPath, filepath.ToSlash(internalDirRel), notebookName)
task.NotebookTask = &jobs.NotebookTask{
NotebookPath: remotePath,
}
return t.Execute(f, data)
}

View File

@ -0,0 +1,90 @@
package mutator
import (
"context"
"fmt"
"os"
"path/filepath"
"testing"
"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/resources"
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/stretchr/testify/require"
)
func getTasks(b *bundle.Bundle) []*jobs.Task {
tasks := make([]*jobs.Task, 0)
for k := range b.Config.Resources.Jobs["test"].Tasks {
tasks = append(tasks, &b.Config.Resources.Jobs["test"].Tasks[k])
}
return tasks
}
func templateData(task *jobs.Task) (map[string]any, error) {
if task.PythonWheelTask == nil {
return nil, fmt.Errorf("PythonWheelTask cannot be nil")
}
data := make(map[string]any)
data["MyName"] = "Trampoline"
return data, nil
}
func cleanUp(task *jobs.Task) {
task.PythonWheelTask = nil
}
func TestGenerateTrampoline(t *testing.T) {
tmpDir := t.TempDir()
tasks := []jobs.Task{
{
TaskKey: "to_trampoline",
PythonWheelTask: &jobs.PythonWheelTask{
PackageName: "test",
EntryPoint: "run",
}},
}
b := &bundle.Bundle{
Config: config.Root{
Path: tmpDir,
Bundle: config.Bundle{
Target: "development",
},
Resources: config.Resources{
Jobs: map[string]*resources.Job{
"test": {
Paths: resources.Paths{
ConfigFilePath: tmpDir,
},
JobSettings: &jobs.JobSettings{
Tasks: tasks,
},
},
},
},
},
}
ctx := context.Background()
trampoline := NewTrampoline("test_trampoline", getTasks, templateData, cleanUp, "Hello from {{.MyName}}")
err := bundle.Apply(ctx, b, trampoline)
require.NoError(t, err)
dir, err := b.InternalDir()
require.NoError(t, err)
filename := filepath.Join(dir, "notebook_to_trampoline.py")
bytes, err := os.ReadFile(filename)
require.NoError(t, err)
require.Equal(t, "Hello from Trampoline", string(bytes))
task := b.Config.Resources.Jobs["test"].Tasks[0]
require.Equal(t, task.NotebookTask.NotebookPath, ".databricks/bundle/development/.internal/notebook_to_trampoline")
require.Nil(t, task.PythonWheelTask)
}

View File

@ -1,17 +1,13 @@
package python package python
import ( import (
"context"
"fmt" "fmt"
"os" "strconv"
"path"
"path/filepath"
"strings" "strings"
"text/template"
"github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config/mutator"
"github.com/databricks/cli/bundle/libraries" "github.com/databricks/cli/bundle/libraries"
"github.com/databricks/databricks-sdk-go/service/compute"
"github.com/databricks/databricks-sdk-go/service/jobs" "github.com/databricks/databricks-sdk-go/service/jobs"
) )
@ -40,105 +36,32 @@ dbutils.notebook.exit(s)
// which installs uploaded wheels using %pip and then calling corresponding // which installs uploaded wheels using %pip and then calling corresponding
// entry point. // entry point.
func TransformWheelTask() bundle.Mutator { func TransformWheelTask() bundle.Mutator {
return &transform{} return mutator.NewTrampoline(
"python_wheel",
getTasks,
generateTemplateData,
cleanUpTask,
NOTEBOOK_TEMPLATE,
)
} }
type transform struct { func getTasks(b *bundle.Bundle) []*jobs.Task {
return libraries.FindAllWheelTasks(b)
} }
func (m *transform) Name() string { func generateTemplateData(task *jobs.Task) (map[string]any, error) {
return "python.TransformWheelTask" params, err := generateParameters(task.PythonWheelTask)
}
func (m *transform) Apply(ctx context.Context, b *bundle.Bundle) error {
wheelTasks := libraries.FindAllWheelTasks(b)
for _, wheelTask := range wheelTasks {
err := generateNotebookTrampoline(b, wheelTask)
if err != nil {
return err
}
}
return nil
}
func generateNotebookTrampoline(b *bundle.Bundle, wheelTask *jobs.Task) error {
taskDefinition := wheelTask.PythonWheelTask
libraries := wheelTask.Libraries
wheelTask.PythonWheelTask = nil
wheelTask.Libraries = nil
filename, err := generateNotebookWrapper(b, taskDefinition, libraries)
if err != nil { if err != nil {
return err return nil, err
}
internalDir, err := getInternalDir(b)
if err != nil {
return err
}
internalDirRel, err := filepath.Rel(b.Config.Path, internalDir)
if err != nil {
return err
}
parts := []string{b.Config.Workspace.FilesPath}
parts = append(parts, strings.Split(internalDirRel, string(os.PathSeparator))...)
parts = append(parts, filename)
wheelTask.NotebookTask = &jobs.NotebookTask{
NotebookPath: path.Join(parts...),
}
return nil
}
func getInternalDir(b *bundle.Bundle) (string, error) {
cacheDir, err := b.CacheDir()
if err != nil {
return "", err
}
internalDir := filepath.Join(cacheDir, ".internal")
return internalDir, nil
}
func generateNotebookWrapper(b *bundle.Bundle, task *jobs.PythonWheelTask, libraries []compute.Library) (string, error) {
internalDir, err := getInternalDir(b)
if err != nil {
return "", err
}
notebookName := fmt.Sprintf("notebook_%s_%s", task.PackageName, task.EntryPoint)
path := filepath.Join(internalDir, notebookName+".py")
err = os.MkdirAll(filepath.Dir(path), 0755)
if err != nil {
return "", err
}
f, err := os.Create(path)
if err != nil {
return "", err
}
defer f.Close()
params, err := generateParameters(task)
if err != nil {
return "", err
} }
data := map[string]any{ data := map[string]any{
"Libraries": libraries, "Libraries": task.Libraries,
"Params": params, "Params": params,
"Task": task, "Task": task.PythonWheelTask,
} }
t, err := template.New("notebook").Parse(NOTEBOOK_TEMPLATE) return data, nil
if err != nil {
return "", err
}
return notebookName, t.Execute(f, data)
} }
func generateParameters(task *jobs.PythonWheelTask) (string, error) { func generateParameters(task *jobs.PythonWheelTask) (string, error) {
@ -149,8 +72,14 @@ func generateParameters(task *jobs.PythonWheelTask) (string, error) {
for k, v := range task.NamedParameters { for k, v := range task.NamedParameters {
params = append(params, fmt.Sprintf("%s=%s", k, v)) params = append(params, fmt.Sprintf("%s=%s", k, v))
} }
for i := range params { for i := range params {
params[i] = `"` + params[i] + `"` params[i] = strconv.Quote(params[i])
} }
return strings.Join(params, ", "), nil return strings.Join(params, ", "), nil
} }
func cleanUpTask(task *jobs.Task) {
task.PythonWheelTask = nil
task.Libraries = nil
}

View File

@ -22,12 +22,14 @@ var paramsTestCases []testCase = []testCase{
{[]string{"a"}, `"python", "a"`}, {[]string{"a"}, `"python", "a"`},
{[]string{"a", "b"}, `"python", "a", "b"`}, {[]string{"a", "b"}, `"python", "a", "b"`},
{[]string{"123!@#$%^&*()-="}, `"python", "123!@#$%^&*()-="`}, {[]string{"123!@#$%^&*()-="}, `"python", "123!@#$%^&*()-="`},
{[]string{`{"a": 1}`}, `"python", "{\"a\": 1}"`},
} }
var paramsTestCasesNamed []testCaseNamed = []testCaseNamed{ var paramsTestCasesNamed []testCaseNamed = []testCaseNamed{
{NamedParams{}, `"python"`}, {NamedParams{}, `"python"`},
{NamedParams{"a": "1"}, `"python", "a=1"`}, {NamedParams{"a": "1"}, `"python", "a=1"`},
{NamedParams{"a": "1", "b": "2"}, `"python", "a=1", "b=2"`}, {NamedParams{"a": "1", "b": "2"}, `"python", "a=1", "b=2"`},
{NamedParams{"data": `{"a": 1}`}, `"python", "data={\"a\": 1}"`},
} }
func TestGenerateParameters(t *testing.T) { func TestGenerateParameters(t *testing.T) {