mirror of https://github.com/databricks/cli.git
address feedback
This commit is contained in:
parent
6c16dc2bef
commit
8ae4bde773
|
@ -9,17 +9,13 @@ import (
|
||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
"github.com/databricks/cli/bundle"
|
"github.com/databricks/cli/bundle"
|
||||||
|
jobs_utils "github.com/databricks/cli/libs/jobs"
|
||||||
"github.com/databricks/databricks-sdk-go/service/jobs"
|
"github.com/databricks/databricks-sdk-go/service/jobs"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TaskWithJobKey struct {
|
|
||||||
Task *jobs.Task
|
|
||||||
JobKey string
|
|
||||||
}
|
|
||||||
|
|
||||||
type TrampolineFunctions interface {
|
type TrampolineFunctions interface {
|
||||||
GetTemplateData(b *bundle.Bundle, task *jobs.Task) (map[string]any, error)
|
GetTemplateData(b *bundle.Bundle, task *jobs.Task) (map[string]any, error)
|
||||||
GetTasks(b *bundle.Bundle) []TaskWithJobKey
|
GetTasks(b *bundle.Bundle) []jobs_utils.TaskWithJobKey
|
||||||
GetTemplate(b *bundle.Bundle, task *jobs.Task) (string, error)
|
GetTemplate(b *bundle.Bundle, task *jobs.Task) (string, error)
|
||||||
CleanUp(task *jobs.Task) error
|
CleanUp(task *jobs.Task) error
|
||||||
}
|
}
|
||||||
|
@ -35,22 +31,6 @@ func NewTrampoline(
|
||||||
return &trampoline{name, functions}
|
return &trampoline{name, functions}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetTasksWithJobKeyBy(b *bundle.Bundle, filter func(*jobs.Task) bool) []TaskWithJobKey {
|
|
||||||
tasks := make([]TaskWithJobKey, 0)
|
|
||||||
for k := range b.Config.Resources.Jobs {
|
|
||||||
for i := range b.Config.Resources.Jobs[k].Tasks {
|
|
||||||
t := &b.Config.Resources.Jobs[k].Tasks[i]
|
|
||||||
if filter(t) {
|
|
||||||
tasks = append(tasks, TaskWithJobKey{
|
|
||||||
JobKey: k,
|
|
||||||
Task: t,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return tasks
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *trampoline) Name() string {
|
func (m *trampoline) Name() string {
|
||||||
return fmt.Sprintf("trampoline(%s)", m.name)
|
return fmt.Sprintf("trampoline(%s)", m.name)
|
||||||
}
|
}
|
||||||
|
@ -66,7 +46,7 @@ func (m *trampoline) Apply(ctx context.Context, b *bundle.Bundle) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *trampoline) generateNotebookWrapper(b *bundle.Bundle, task TaskWithJobKey) error {
|
func (m *trampoline) generateNotebookWrapper(b *bundle.Bundle, task jobs_utils.TaskWithJobKey) error {
|
||||||
internalDir, err := b.InternalDir()
|
internalDir, err := b.InternalDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
|
|
||||||
"github.com/databricks/cli/bundle"
|
"github.com/databricks/cli/bundle"
|
||||||
"github.com/databricks/cli/bundle/config/mutator"
|
"github.com/databricks/cli/bundle/config/mutator"
|
||||||
|
jobs_utils "github.com/databricks/cli/libs/jobs"
|
||||||
"github.com/databricks/databricks-sdk-go/service/jobs"
|
"github.com/databricks/databricks-sdk-go/service/jobs"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -44,20 +45,21 @@ func localNotebookPath(b *bundle.Bundle, task *jobs.Task) (string, error) {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return fmt.Sprintf("%s.py", localPath), nil
|
return fmt.Sprintf("%s.py", localPath), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", fmt.Errorf("notebook %s not found", localPath)
|
return "", fmt.Errorf("notebook %s not found", localPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *notebookTrampoline) GetTasks(b *bundle.Bundle) []mutator.TaskWithJobKey {
|
func (n *notebookTrampoline) GetTasks(b *bundle.Bundle) []jobs_utils.TaskWithJobKey {
|
||||||
return mutator.GetTasksWithJobKeyBy(b, func(task *jobs.Task) bool {
|
return jobs_utils.GetTasksWithJobKeyBy(b, func(task *jobs.Task) bool {
|
||||||
if task.NotebookTask == nil ||
|
if task.NotebookTask == nil ||
|
||||||
task.NotebookTask.Source == jobs.SourceGit {
|
task.NotebookTask.Source == jobs.SourceGit {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
localPath, err := localNotebookPath(b, task)
|
_, err := localNotebookPath(b, task)
|
||||||
if err != nil {
|
// We assume if the notebook is not available locally in the bundle
|
||||||
return false
|
// then the user has it somewhere in the workspace. For these
|
||||||
}
|
// out of bundle notebooks we do not want to write a trampoline.
|
||||||
return strings.HasSuffix(localPath, ".ipynb") || strings.HasSuffix(localPath, ".py")
|
return err == nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -66,13 +68,6 @@ func (n *notebookTrampoline) CleanUp(task *jobs.Task) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *notebookTrampoline) GetTemplate(b *bundle.Bundle, task *jobs.Task) (string, error) {
|
func (n *notebookTrampoline) GetTemplate(b *bundle.Bundle, task *jobs.Task) (string, error) {
|
||||||
if task.NotebookTask == nil {
|
|
||||||
return "", fmt.Errorf("nil notebook path")
|
|
||||||
}
|
|
||||||
|
|
||||||
if task.NotebookTask.Source == jobs.SourceGit {
|
|
||||||
return "", fmt.Errorf("source must be workspace %s", task.NotebookTask.Source)
|
|
||||||
}
|
|
||||||
localPath, err := localNotebookPath(b, task)
|
localPath, err := localNotebookPath(b, task)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
@ -89,14 +84,10 @@ func (n *notebookTrampoline) GetTemplate(b *bundle.Bundle, task *jobs.Task) (str
|
||||||
|
|
||||||
lines := strings.Split(s, "\n")
|
lines := strings.Split(s, "\n")
|
||||||
if strings.HasPrefix(lines[0], "# Databricks notebook source") {
|
if strings.HasPrefix(lines[0], "# Databricks notebook source") {
|
||||||
return getDbnbTemplate(strings.Join(lines, "\n"))
|
return getDbnbTemplate(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
return getPyTemplate(s), nil
|
return pyTrampolineData, nil
|
||||||
}
|
|
||||||
|
|
||||||
func getPyTemplate(s string) string {
|
|
||||||
return pyTrampolineData
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getDbnbTemplate(s string) (string, error) {
|
func getDbnbTemplate(s string) (string, error) {
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
|
|
||||||
"github.com/databricks/cli/bundle"
|
"github.com/databricks/cli/bundle"
|
||||||
"github.com/databricks/cli/bundle/config/mutator"
|
"github.com/databricks/cli/bundle/config/mutator"
|
||||||
|
jobs_utils "github.com/databricks/cli/libs/jobs"
|
||||||
"github.com/databricks/databricks-sdk-go/service/jobs"
|
"github.com/databricks/databricks-sdk-go/service/jobs"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -67,8 +68,8 @@ func (t *pythonTrampoline) GetTemplate(b *bundle.Bundle, task *jobs.Task) (strin
|
||||||
return NOTEBOOK_TEMPLATE, nil
|
return NOTEBOOK_TEMPLATE, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *pythonTrampoline) GetTasks(b *bundle.Bundle) []mutator.TaskWithJobKey {
|
func (t *pythonTrampoline) GetTasks(b *bundle.Bundle) []jobs_utils.TaskWithJobKey {
|
||||||
return mutator.GetTasksWithJobKeyBy(b, func(task *jobs.Task) bool {
|
return jobs_utils.GetTasksWithJobKeyBy(b, func(task *jobs.Task) bool {
|
||||||
return task.PythonWheelTask != nil
|
return task.PythonWheelTask != nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
package jobs_utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/databricks/cli/bundle"
|
||||||
|
"github.com/databricks/databricks-sdk-go/service/jobs"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TaskWithJobKey struct {
|
||||||
|
Task *jobs.Task
|
||||||
|
JobKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetTasksWithJobKeyBy(b *bundle.Bundle, filter func(*jobs.Task) bool) []TaskWithJobKey {
|
||||||
|
tasks := make([]TaskWithJobKey, 0)
|
||||||
|
for k := range b.Config.Resources.Jobs {
|
||||||
|
for i := range b.Config.Resources.Jobs[k].Tasks {
|
||||||
|
t := &b.Config.Resources.Jobs[k].Tasks[i]
|
||||||
|
if filter(t) {
|
||||||
|
tasks = append(tasks, TaskWithJobKey{
|
||||||
|
JobKey: k,
|
||||||
|
Task: t,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tasks
|
||||||
|
}
|
|
@ -0,0 +1,62 @@
|
||||||
|
package jobs_utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/databricks/cli/bundle"
|
||||||
|
"github.com/databricks/cli/bundle/config"
|
||||||
|
"github.com/databricks/cli/bundle/config/resources"
|
||||||
|
"github.com/databricks/databricks-sdk-go/service/jobs"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCorrectlyFilterTasksByFn(t *testing.T) {
|
||||||
|
bundle := &bundle.Bundle{
|
||||||
|
Config: config.Root{
|
||||||
|
Resources: config.Resources{
|
||||||
|
Jobs: map[string]*resources.Job{
|
||||||
|
"job1": {
|
||||||
|
JobSettings: &jobs.JobSettings{
|
||||||
|
Tasks: []jobs.Task{
|
||||||
|
{
|
||||||
|
TaskKey: "job1_key1",
|
||||||
|
PythonWheelTask: &jobs.PythonWheelTask{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
TaskKey: "job1_key2",
|
||||||
|
NotebookTask: &jobs.NotebookTask{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"job2": {
|
||||||
|
JobSettings: &jobs.JobSettings{
|
||||||
|
Tasks: []jobs.Task{
|
||||||
|
{
|
||||||
|
TaskKey: "job1_key1",
|
||||||
|
PythonWheelTask: &jobs.PythonWheelTask{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
TaskKey: "job2_key2",
|
||||||
|
NotebookTask: &jobs.NotebookTask{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks := GetTasksWithJobKeyBy(bundle, func(task *jobs.Task) bool {
|
||||||
|
return task.PythonWheelTask != nil
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Len(t, tasks, 2)
|
||||||
|
|
||||||
|
require.Equal(t, "job1", tasks[0].JobKey)
|
||||||
|
require.Equal(t, "job1_key1", tasks[0].Task.TaskKey)
|
||||||
|
|
||||||
|
require.Equal(t, "job2", tasks[1].JobKey)
|
||||||
|
require.Equal(t, "job1_key1", tasks[1].Task.TaskKey)
|
||||||
|
}
|
Loading…
Reference in New Issue