address feedback

This commit is contained in:
kartikgupta-db 2023-09-13 12:50:29 +02:00
parent 6c16dc2bef
commit 8ae4bde773
No known key found for this signature in database
GPG Key ID: 6AD5FA11FACDEA39
5 changed files with 106 additions and 45 deletions

View File

@ -9,17 +9,13 @@ import (
"text/template"
"github.com/databricks/cli/bundle"
jobs_utils "github.com/databricks/cli/libs/jobs"
"github.com/databricks/databricks-sdk-go/service/jobs"
)
type TaskWithJobKey struct {
Task *jobs.Task
JobKey string
}
type TrampolineFunctions interface {
GetTemplateData(b *bundle.Bundle, task *jobs.Task) (map[string]any, error)
GetTasks(b *bundle.Bundle) []TaskWithJobKey
GetTasks(b *bundle.Bundle) []jobs_utils.TaskWithJobKey
GetTemplate(b *bundle.Bundle, task *jobs.Task) (string, error)
CleanUp(task *jobs.Task) error
}
@ -35,22 +31,6 @@ func NewTrampoline(
return &trampoline{name, functions}
}
func GetTasksWithJobKeyBy(b *bundle.Bundle, filter func(*jobs.Task) bool) []TaskWithJobKey {
tasks := make([]TaskWithJobKey, 0)
for k := range b.Config.Resources.Jobs {
for i := range b.Config.Resources.Jobs[k].Tasks {
t := &b.Config.Resources.Jobs[k].Tasks[i]
if filter(t) {
tasks = append(tasks, TaskWithJobKey{
JobKey: k,
Task: t,
})
}
}
}
return tasks
}
// Name returns the identifier of this trampoline, formatted as
// "trampoline(<name>)" from the name stored on the receiver.
func (m *trampoline) Name() string {
return fmt.Sprintf("trampoline(%s)", m.name)
}
@ -66,7 +46,7 @@ func (m *trampoline) Apply(ctx context.Context, b *bundle.Bundle) error {
return nil
}
func (m *trampoline) generateNotebookWrapper(b *bundle.Bundle, task TaskWithJobKey) error {
func (m *trampoline) generateNotebookWrapper(b *bundle.Bundle, task jobs_utils.TaskWithJobKey) error {
internalDir, err := b.InternalDir()
if err != nil {
return err

View File

@ -10,6 +10,7 @@ import (
"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config/mutator"
jobs_utils "github.com/databricks/cli/libs/jobs"
"github.com/databricks/databricks-sdk-go/service/jobs"
)
@ -44,20 +45,21 @@ func localNotebookPath(b *bundle.Bundle, task *jobs.Task) (string, error) {
if err == nil {
return fmt.Sprintf("%s.py", localPath), nil
}
return "", fmt.Errorf("notebook %s not found", localPath)
}
func (n *notebookTrampoline) GetTasks(b *bundle.Bundle) []mutator.TaskWithJobKey {
return mutator.GetTasksWithJobKeyBy(b, func(task *jobs.Task) bool {
func (n *notebookTrampoline) GetTasks(b *bundle.Bundle) []jobs_utils.TaskWithJobKey {
return jobs_utils.GetTasksWithJobKeyBy(b, func(task *jobs.Task) bool {
if task.NotebookTask == nil ||
task.NotebookTask.Source == jobs.SourceGit {
return false
}
localPath, err := localNotebookPath(b, task)
if err != nil {
return false
}
return strings.HasSuffix(localPath, ".ipynb") || strings.HasSuffix(localPath, ".py")
_, err := localNotebookPath(b, task)
// We assume if the notebook is not available locally in the bundle
// then the user has it somewhere in the workspace. For these
// out of bundle notebooks we do not want to write a trampoline.
return err == nil
})
}
@ -66,13 +68,6 @@ func (n *notebookTrampoline) CleanUp(task *jobs.Task) error {
}
func (n *notebookTrampoline) GetTemplate(b *bundle.Bundle, task *jobs.Task) (string, error) {
if task.NotebookTask == nil {
return "", fmt.Errorf("nil notebook path")
}
if task.NotebookTask.Source == jobs.SourceGit {
return "", fmt.Errorf("source must be workspace %s", task.NotebookTask.Source)
}
localPath, err := localNotebookPath(b, task)
if err != nil {
return "", err
@ -89,14 +84,10 @@ func (n *notebookTrampoline) GetTemplate(b *bundle.Bundle, task *jobs.Task) (str
lines := strings.Split(s, "\n")
if strings.HasPrefix(lines[0], "# Databricks notebook source") {
return getDbnbTemplate(strings.Join(lines, "\n"))
return getDbnbTemplate(s)
}
return getPyTemplate(s), nil
}
func getPyTemplate(s string) string {
return pyTrampolineData
return pyTrampolineData, nil
}
func getDbnbTemplate(s string) (string, error) {

View File

@ -7,6 +7,7 @@ import (
"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config/mutator"
jobs_utils "github.com/databricks/cli/libs/jobs"
"github.com/databricks/databricks-sdk-go/service/jobs"
)
@ -67,8 +68,8 @@ func (t *pythonTrampoline) GetTemplate(b *bundle.Bundle, task *jobs.Task) (strin
return NOTEBOOK_TEMPLATE, nil
}
func (t *pythonTrampoline) GetTasks(b *bundle.Bundle) []mutator.TaskWithJobKey {
return mutator.GetTasksWithJobKeyBy(b, func(task *jobs.Task) bool {
func (t *pythonTrampoline) GetTasks(b *bundle.Bundle) []jobs_utils.TaskWithJobKey {
return jobs_utils.GetTasksWithJobKeyBy(b, func(task *jobs.Task) bool {
return task.PythonWheelTask != nil
})
}

27
libs/jobs/utils.go Normal file
View File

@ -0,0 +1,27 @@
package jobs_utils
import (
	"sort"

	"github.com/databricks/cli/bundle"
	"github.com/databricks/databricks-sdk-go/service/jobs"
)
// TaskWithJobKey pairs a job task with the key of the job it belongs to.
type TaskWithJobKey struct {
// Task points at the task definition.
Task *jobs.Task
// JobKey is the key under which the owning job is registered in the
// bundle's job resources.
JobKey string
}
// GetTasksWithJobKeyBy returns every task defined in the bundle's job
// resources for which filter reports true, paired with the key of the job
// that defines it.
//
// Jobs are visited in ascending key order and tasks in their declared order,
// so the result is deterministic. Each returned Task points into the bundle's
// configuration, so changes made through the pointer are reflected in the
// bundle. Jobs without settings are skipped. The result is never nil.
func GetTasksWithJobKeyBy(b *bundle.Bundle, filter func(*jobs.Task) bool) []TaskWithJobKey {
	// Go randomizes map iteration order; sort the keys so that callers (and
	// tests) observe a stable ordering.
	keys := make([]string, 0, len(b.Config.Resources.Jobs))
	for k := range b.Config.Resources.Jobs {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	tasks := make([]TaskWithJobKey, 0)
	for _, k := range keys {
		job := b.Config.Resources.Jobs[k]
		// The job embeds *jobs.JobSettings; accessing Tasks through a nil
		// job or nil settings would panic.
		if job == nil || job.JobSettings == nil {
			continue
		}
		for i := range job.Tasks {
			// Take the address of the slice element, not of a loop copy, so
			// that callers can mutate the task in place.
			t := &job.Tasks[i]
			if filter(t) {
				tasks = append(tasks, TaskWithJobKey{
					JobKey: k,
					Task:   t,
				})
			}
		}
	}
	return tasks
}

62
libs/jobs/utils_test.go Normal file
View File

@ -0,0 +1,62 @@
package jobs_utils
import (
"testing"
"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/resources"
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/stretchr/testify/require"
)
// TestCorrectlyFilterTasksByFn checks that GetTasksWithJobKeyBy returns only
// the tasks accepted by the filter and associates each with its job key.
func TestCorrectlyFilterTasksByFn(t *testing.T) {
	// Named b rather than bundle to avoid shadowing the imported package.
	b := &bundle.Bundle{
		Config: config.Root{
			Resources: config.Resources{
				Jobs: map[string]*resources.Job{
					"job1": {
						JobSettings: &jobs.JobSettings{
							Tasks: []jobs.Task{
								{
									TaskKey:         "job1_key1",
									PythonWheelTask: &jobs.PythonWheelTask{},
								},
								{
									TaskKey:      "job1_key2",
									NotebookTask: &jobs.NotebookTask{},
								},
							},
						},
					},
					"job2": {
						JobSettings: &jobs.JobSettings{
							Tasks: []jobs.Task{
								{
									// Distinct from job1's key so that a task
									// attributed to the wrong job is detected.
									TaskKey:         "job2_key1",
									PythonWheelTask: &jobs.PythonWheelTask{},
								},
								{
									TaskKey:      "job2_key2",
									NotebookTask: &jobs.NotebookTask{},
								},
							},
						},
					},
				},
			},
		},
	}

	tasks := GetTasksWithJobKeyBy(b, func(task *jobs.Task) bool {
		return task.PythonWheelTask != nil
	})

	require.Len(t, tasks, 2)

	// Compare as a job-key -> task-key mapping so the assertion does not
	// depend on the order in which the jobs map is traversed.
	got := make(map[string]string, len(tasks))
	for _, task := range tasks {
		got[task.JobKey] = task.Task.TaskKey
	}
	require.Equal(t, map[string]string{
		"job1": "job1_key1",
		"job2": "job2_key1",
	}, got)
}