databricks-cli/bundle/trampoline/python_wheel.go

163 lines
3.9 KiB
Go

package trampoline
import (
	"context"
	"errors"
	"fmt"
	"sort"
	"strconv"
	"strings"

	"github.com/databricks/cli/bundle"
	"github.com/databricks/cli/bundle/config/mutator"
	"github.com/databricks/cli/bundle/libraries"
	"github.com/databricks/databricks-sdk-go/service/compute"
	"github.com/databricks/databricks-sdk-go/service/jobs"
)
// NOTEBOOK_TEMPLATE is the text/template-style source of the notebook that
// stands in for a Python wheel task. It is passed to NewTrampoline by
// TransformWheelTask and is rendered with the data produced by
// GetTemplateData, using these fields:
//
//   - .Libraries: each element's .Whl is installed with
//     "%pip install --force-reinstall", after which Python is restarted.
//   - .Params: spliced verbatim into a Python list literal that becomes
//     sys.argv, so it must already be a comma-separated list of quoted strings.
//   - .Task.PackageName / .Task.EntryPoint: the distribution and the entry
//     point name to look up through importlib metadata.
//
// At run time the notebook also reads the "__python_params" widget; when it
// holds a non-empty JSON value, that value replaces every argument after
// sys.argv[0]. Failures while reading the widget are printed and ignored.
// The entry point runs with stdout redirected into a buffer whose content is
// handed to dbutils.notebook.exit; a missing entry point raises ImportError.
const NOTEBOOK_TEMPLATE = `# Databricks notebook source
%python
{{range .Libraries}}
%pip install --force-reinstall {{.Whl}}
{{end}}
dbutils.library.restartPython()
try:
from importlib import metadata
except ImportError: # for Python<3.8
import subprocess
import sys
subprocess.check_call([sys.executable, "-m", "pip", "install", "importlib-metadata"])
import importlib_metadata as metadata
from contextlib import redirect_stdout
import io
import sys
import json
params = []
try:
python_params = dbutils.widgets.get("__python_params")
if python_params:
params = json.loads(python_params)
except Exception as e:
print(e)
sys.argv = [{{.Params}}]
if params:
sys.argv = [sys.argv[0]] + params
entry = [ep for ep in metadata.distribution("{{.Task.PackageName}}").entry_points if ep.name == "{{.Task.EntryPoint}}"]
f = io.StringIO()
with redirect_stdout(f):
if entry:
entry[0].load()()
else:
raise ImportError("Entry point '{{.Task.EntryPoint}}' not found")
s = f.getvalue()
dbutils.notebook.exit(s)
`
// TransformWheelTask returns a mutator that turns Python wheel tasks into
// notebooks. The generated notebook installs the uploaded wheels with %pip
// and then calls the corresponding entry point.
//
// The transformation only runs when b.Config.Experimental is set and its
// PythonWheelWrapper flag is true; in every other case the returned mutator
// falls through to mutator.NoOp().
func TransformWheelTask() bundle.Mutator {
	// Predicate evaluated by bundle.If to pick between the two mutators.
	wrapperEnabled := func(_ context.Context, b *bundle.Bundle) (bool, error) {
		experimental := b.Config.Experimental
		if experimental == nil {
			return false, nil
		}
		return experimental.PythonWheelWrapper, nil
	}

	wheelTrampoline := NewTrampoline("python_wheel", &pythonTrampoline{}, NOTEBOOK_TEMPLATE)
	return bundle.If(wrapperEnabled, wheelTrampoline, mutator.NoOp())
}
// pythonTrampoline is the stateless value handed to NewTrampoline for Python
// wheel tasks. Its methods select the tasks to wrap (GetTasks), build the
// notebook template data (GetTemplateData) and strip the wheel-specific
// settings from a task (CleanUp).
type pythonTrampoline struct{}
// CleanUp removes everything wheel-related from the task: the
// PythonWheelTask definition is cleared and only libraries without a Whl
// path are kept. Wheel libraries are dropped because NOTEBOOK_TEMPLATE
// installs them itself through %pip.
//
// The task is modified in place. The returned error is always nil.
func (t *pythonTrampoline) CleanUp(task *jobs.Task) error {
	// Start from a non-nil, empty slice so that a task with only wheel
	// libraries ends up with an empty (not nil) Libraries field.
	remaining := make([]compute.Library, 0)
	for i := range task.Libraries {
		if lib := task.Libraries[i]; lib.Whl == "" {
			remaining = append(remaining, lib)
		}
	}

	task.PythonWheelTask = nil
	task.Libraries = remaining
	return nil
}
// GetTasks walks every job in the bundle and returns the tasks that have to
// be wrapped, each paired with the key of the job it belongs to.
//
// The returned Task pointers refer to the elements of the job's own Tasks
// slice, so changes made through them are visible in the bundle
// configuration. Jobs are visited in map iteration order, which is not
// stable between calls.
func (t *pythonTrampoline) GetTasks(b *bundle.Bundle) []TaskWithJobKey {
	selected := make([]TaskWithJobKey, 0)
	for jobKey, job := range b.Config.Resources.Jobs {
		jobTasks := job.JobSettings.Tasks
		for idx := range jobTasks {
			candidate := &jobTasks[idx]

			// Keep only Python wheel tasks that reference workspace libraries.
			// At this point there are no local paths left in the Libraries
			// sections: they were replaced with remote paths when the
			// artifacts were uploaded by the artifacts.UploadAll mutator.
			if candidate.PythonWheelTask != nil && needsTrampoline(*candidate) {
				selected = append(selected, TaskWithJobKey{
					JobKey: jobKey,
					Task:   candidate,
				})
			}
		}
	}
	return selected
}
// needsTrampoline reports whether the task should be wrapped into a
// notebook. The decision is delegated entirely to
// libraries.IsTaskWithWorkspaceLibraries; GetTasks uses it to keep only
// tasks that reference workspace libraries.
func needsTrampoline(task jobs.Task) bool {
return libraries.IsTaskWithWorkspaceLibraries(task)
}
// GetTemplateData builds the values consumed by NOTEBOOK_TEMPLATE for a
// single task:
//
//   - "Libraries": the task libraries that have a non-empty Whl path,
//   - "Params": the argument list produced by generateParameters,
//   - "Task": the task's PythonWheelTask definition.
//
// It returns a nil map together with the error when generateParameters
// rejects the task's parameters.
func (t *pythonTrampoline) GetTemplateData(task *jobs.Task) (map[string]any, error) {
	params, err := t.generateParameters(task.PythonWheelTask)
	if err != nil {
		return nil, err
	}

	// Only wheel libraries are installed by the notebook; skip the rest.
	wheels := make([]compute.Library, 0)
	for _, lib := range task.Libraries {
		if lib.Whl == "" {
			continue
		}
		wheels = append(wheels, lib)
	}

	return map[string]any{
		"Libraries": wheels,
		"Params":    params,
		"Task":      task.PythonWheelTask,
	}, nil
}
// generateParameters renders the argument list that NOTEBOOK_TEMPLATE places
// inside the sys.argv list literal. The result is a comma-separated sequence
// of quoted strings: the package name first (it becomes sys.argv[0]),
// followed by either the positional Parameters in their original order or
// the NamedParameters rendered as "name=value".
//
// Named parameters are emitted in ascending key order. Ranging over the map
// directly would yield a different order on every run, which would make the
// generated notebook (and the argument order seen by the entry point)
// non-deterministic.
//
// An error is returned when both Parameters and NamedParameters are set.
func (t *pythonTrampoline) generateParameters(task *jobs.PythonWheelTask) (string, error) {
	if task.Parameters != nil && task.NamedParameters != nil {
		// NOTE(review): the message spelling is left untouched on purpose;
		// callers or tests may match on the exact text.
		return "", errors.New("not allowed to pass both paramaters and named_parameters")
	}

	// Collect and sort the keys so the output is stable across runs.
	names := make([]string, 0, len(task.NamedParameters))
	for name := range task.NamedParameters {
		names = append(names, name)
	}
	sort.Strings(names)

	params := make([]string, 0, 1+len(task.Parameters)+len(names))
	params = append(params, strconv.Quote(task.PackageName))
	for _, p := range task.Parameters {
		params = append(params, strconv.Quote(p))
	}
	for _, name := range names {
		arg := fmt.Sprintf("%s=%s", name, task.NamedParameters[name])
		params = append(params, strconv.Quote(arg))
	}

	return strings.Join(params, ", "), nil
}