diff --git a/bundle/artifacts/artifacts.go b/bundle/artifacts/artifacts.go index c5413121..73a6d447 100644 --- a/bundle/artifacts/artifacts.go +++ b/bundle/artifacts/artifacts.go @@ -165,3 +165,36 @@ func getUploadBasePath(b *bundle.Bundle) (string, error) { return path.Join(artifactPath, ".internal"), nil } + +func UploadNotebook(ctx context.Context, notebook string, b *bundle.Bundle) (string, error) { + raw, err := os.ReadFile(notebook) + if err != nil { + return "", fmt.Errorf("unable to read %s: %w", notebook, errors.Unwrap(err)) + } + + uploadPath, err := getUploadBasePath(b) + if err != nil { + return "", err + } + + remotePath := path.Join(uploadPath, path.Base(notebook)) + // Make sure target directory exists. + err = b.WorkspaceClient().Workspace.MkdirsByPath(ctx, path.Dir(remotePath)) + if err != nil { + return "", fmt.Errorf("unable to create directory for %s: %w", remotePath, err) + } + + // Import to workspace. + err = b.WorkspaceClient().Workspace.Import(ctx, workspace.Import{ + Path: remotePath, + Overwrite: true, + Format: workspace.ImportFormatSource, + Content: base64.StdEncoding.EncodeToString(raw), + Language: workspace.LanguagePython, + }) + if err != nil { + return "", fmt.Errorf("unable to import %s: %w", remotePath, err) + } + + return remotePath, nil +} diff --git a/bundle/libraries/libraries.go b/bundle/libraries/libraries.go index f7a2574a..0afaf6d4 100644 --- a/bundle/libraries/libraries.go +++ b/bundle/libraries/libraries.go @@ -24,26 +24,48 @@ func (a *match) Name() string { } func (a *match) Apply(ctx context.Context, b *bundle.Bundle) error { - r := b.Config.Resources - for k := range b.Config.Resources.Jobs { - tasks := r.Jobs[k].JobSettings.Tasks - for i := range tasks { - task := &tasks[i] - if isMissingRequiredLibraries(task) { - return fmt.Errorf("task '%s' is missing required libraries. Please include your package code in task libraries block", task.TaskKey) - } - for j := range task.Libraries { - lib := &task.Libraries[j] - err := findArtifactsAndMarkForUpload(ctx, lib, b) - if err != nil { - return err - } + tasks := findAllTasks(b) + for _, task := range tasks { + if isMissingRequiredLibraries(task) { + return fmt.Errorf("task '%s' is missing required libraries. Please include your package code in task libraries block", task.TaskKey) + } + for j := range task.Libraries { + lib := &task.Libraries[j] + err := findArtifactsAndMarkForUpload(ctx, lib, b) + if err != nil { + return err } } } return nil } +func findAllTasks(b *bundle.Bundle) []*jobs.Task { + r := b.Config.Resources + result := make([]*jobs.Task, 0) + for k := range b.Config.Resources.Jobs { + tasks := r.Jobs[k].JobSettings.Tasks + for i := range tasks { + task := &tasks[i] + result = append(result, task) + } + } + + return result +} + +func FindAllWheelTasks(b *bundle.Bundle) []*jobs.Task { + tasks := findAllTasks(b) + wheelTasks := make([]*jobs.Task, 0) + for _, task := range tasks { + if task.PythonWheelTask != nil { + wheelTasks = append(wheelTasks, task) + } + } + + return wheelTasks +} + func isMissingRequiredLibraries(task *jobs.Task) bool { if task.Libraries != nil { return false diff --git a/bundle/phases/deploy.go b/bundle/phases/deploy.go index 011bb4b2..9d9b746e 100644 --- a/bundle/phases/deploy.go +++ b/bundle/phases/deploy.go @@ -8,6 +8,7 @@ import ( "github.com/databricks/cli/bundle/deploy/lock" "github.com/databricks/cli/bundle/deploy/terraform" "github.com/databricks/cli/bundle/libraries" + "github.com/databricks/cli/bundle/python" ) // The deploy phase deploys artifacts and resources. @@ -21,6 +22,7 @@ func Deploy() bundle.Mutator { libraries.MatchWithArtifacts(), artifacts.CleanUp(), artifacts.UploadAll(), + python.TransformWheelTask(), terraform.Interpolate(), terraform.Write(), terraform.StatePull(), diff --git a/bundle/python/transform.go b/bundle/python/transform.go new file mode 100644 index 00000000..401bce7f --- /dev/null +++ b/bundle/python/transform.go @@ -0,0 +1,106 @@ +package python + +import ( + "bytes" + "context" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/databricks/cli/bundle" + "github.com/databricks/cli/bundle/artifacts" + "github.com/databricks/cli/bundle/libraries" + "github.com/databricks/databricks-sdk-go/service/compute" + "github.com/databricks/databricks-sdk-go/service/jobs" +) + +// This mutator takes the wheel task and trasnforms it into notebook +// which installs uploaded wheels using %pip and then calling corresponding +// entry point. +func TransformWheelTask() bundle.Mutator { + return &transform{} +} + +type transform struct { +} + +func (m *transform) Name() string { + return "python.TransformWheelTask" +} + +const INSTALL_WHEEL_CODE = `%%pip install --force-reinstall %s` + +const NOTEBOOK_CODE = ` +%%python +%s + +from contextlib import redirect_stdout +import io +import sys +sys.argv = [%s] + +import pkg_resources +_func = pkg_resources.load_entry_point("%s", "console_scripts", "%s") + +f = io.StringIO() +with redirect_stdout(f): + _func() +s = f.getvalue() +dbutils.notebook.exit(s) +` + +func (m *transform) Apply(ctx context.Context, b *bundle.Bundle) error { + // TODO: do the transformaton only for DBR < 13.1 and (maybe?) existing clusters + wheelTasks := libraries.FindAllWheelTasks(b) + for _, wheelTask := range wheelTasks { + taskDefinition := wheelTask.PythonWheelTask + libraries := wheelTask.Libraries + + wheelTask.PythonWheelTask = nil + wheelTask.Libraries = nil + + path, err := generateNotebookWrapper(taskDefinition, libraries) + if err != nil { + return err + } + + remotePath, err := artifacts.UploadNotebook(context.Background(), path, b) + if err != nil { + return err + } + + os.Remove(path) + + wheelTask.NotebookTask = &jobs.NotebookTask{ + NotebookPath: remotePath, + } + } + return nil +} + +func generateNotebookWrapper(task *jobs.PythonWheelTask, libraries []compute.Library) (string, error) { + pipInstall := "" + for _, lib := range libraries { + pipInstall = pipInstall + "\n" + fmt.Sprintf(INSTALL_WHEEL_CODE, lib.Whl) + } + content := fmt.Sprintf(NOTEBOOK_CODE, pipInstall, generateParameters(task), task.PackageName, task.EntryPoint) + + tmpDir := os.TempDir() + filename := fmt.Sprintf("notebook_%s_%s.ipynb", task.PackageName, task.EntryPoint) + path := filepath.Join(tmpDir, filename) + + err := os.WriteFile(path, bytes.NewBufferString(content).Bytes(), 0644) + return path, err +} + +func generateParameters(task *jobs.PythonWheelTask) string { + params := append([]string{"python"}, task.Parameters...) + for k, v := range task.NamedParameters { + params = append(params, fmt.Sprintf("%s=%s", k, v)) + } + for i := range params { + params[i] = `"` + params[i] + `"` + } + return strings.Join(params, ", ") +}