mirror of https://github.com/databricks/cli.git
address feedback
This commit is contained in:
parent
6c16dc2bef
commit
8ae4bde773
|
@ -9,17 +9,13 @@ import (
|
|||
"text/template"
|
||||
|
||||
"github.com/databricks/cli/bundle"
|
||||
jobs_utils "github.com/databricks/cli/libs/jobs"
|
||||
"github.com/databricks/databricks-sdk-go/service/jobs"
|
||||
)
|
||||
|
||||
type TaskWithJobKey struct {
|
||||
Task *jobs.Task
|
||||
JobKey string
|
||||
}
|
||||
|
||||
type TrampolineFunctions interface {
|
||||
GetTemplateData(b *bundle.Bundle, task *jobs.Task) (map[string]any, error)
|
||||
GetTasks(b *bundle.Bundle) []TaskWithJobKey
|
||||
GetTasks(b *bundle.Bundle) []jobs_utils.TaskWithJobKey
|
||||
GetTemplate(b *bundle.Bundle, task *jobs.Task) (string, error)
|
||||
CleanUp(task *jobs.Task) error
|
||||
}
|
||||
|
@ -35,22 +31,6 @@ func NewTrampoline(
|
|||
return &trampoline{name, functions}
|
||||
}
|
||||
|
||||
func GetTasksWithJobKeyBy(b *bundle.Bundle, filter func(*jobs.Task) bool) []TaskWithJobKey {
|
||||
tasks := make([]TaskWithJobKey, 0)
|
||||
for k := range b.Config.Resources.Jobs {
|
||||
for i := range b.Config.Resources.Jobs[k].Tasks {
|
||||
t := &b.Config.Resources.Jobs[k].Tasks[i]
|
||||
if filter(t) {
|
||||
tasks = append(tasks, TaskWithJobKey{
|
||||
JobKey: k,
|
||||
Task: t,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return tasks
|
||||
}
|
||||
|
||||
func (m *trampoline) Name() string {
|
||||
return fmt.Sprintf("trampoline(%s)", m.name)
|
||||
}
|
||||
|
@ -66,7 +46,7 @@ func (m *trampoline) Apply(ctx context.Context, b *bundle.Bundle) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (m *trampoline) generateNotebookWrapper(b *bundle.Bundle, task TaskWithJobKey) error {
|
||||
func (m *trampoline) generateNotebookWrapper(b *bundle.Bundle, task jobs_utils.TaskWithJobKey) error {
|
||||
internalDir, err := b.InternalDir()
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
"github.com/databricks/cli/bundle"
|
||||
"github.com/databricks/cli/bundle/config/mutator"
|
||||
jobs_utils "github.com/databricks/cli/libs/jobs"
|
||||
"github.com/databricks/databricks-sdk-go/service/jobs"
|
||||
)
|
||||
|
||||
|
@ -44,20 +45,21 @@ func localNotebookPath(b *bundle.Bundle, task *jobs.Task) (string, error) {
|
|||
if err == nil {
|
||||
return fmt.Sprintf("%s.py", localPath), nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("notebook %s not found", localPath)
|
||||
}
|
||||
|
||||
func (n *notebookTrampoline) GetTasks(b *bundle.Bundle) []mutator.TaskWithJobKey {
|
||||
return mutator.GetTasksWithJobKeyBy(b, func(task *jobs.Task) bool {
|
||||
func (n *notebookTrampoline) GetTasks(b *bundle.Bundle) []jobs_utils.TaskWithJobKey {
|
||||
return jobs_utils.GetTasksWithJobKeyBy(b, func(task *jobs.Task) bool {
|
||||
if task.NotebookTask == nil ||
|
||||
task.NotebookTask.Source == jobs.SourceGit {
|
||||
return false
|
||||
}
|
||||
localPath, err := localNotebookPath(b, task)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return strings.HasSuffix(localPath, ".ipynb") || strings.HasSuffix(localPath, ".py")
|
||||
_, err := localNotebookPath(b, task)
|
||||
// We assume if the notebook is not available locally in the bundle
|
||||
// then the user has it somewhere in the workspace. For these
|
||||
// out of bundle notebooks we do not want to write a trampoline.
|
||||
return err == nil
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -66,13 +68,6 @@ func (n *notebookTrampoline) CleanUp(task *jobs.Task) error {
|
|||
}
|
||||
|
||||
func (n *notebookTrampoline) GetTemplate(b *bundle.Bundle, task *jobs.Task) (string, error) {
|
||||
if task.NotebookTask == nil {
|
||||
return "", fmt.Errorf("nil notebook path")
|
||||
}
|
||||
|
||||
if task.NotebookTask.Source == jobs.SourceGit {
|
||||
return "", fmt.Errorf("source must be workspace %s", task.NotebookTask.Source)
|
||||
}
|
||||
localPath, err := localNotebookPath(b, task)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
@ -89,14 +84,10 @@ func (n *notebookTrampoline) GetTemplate(b *bundle.Bundle, task *jobs.Task) (str
|
|||
|
||||
lines := strings.Split(s, "\n")
|
||||
if strings.HasPrefix(lines[0], "# Databricks notebook source") {
|
||||
return getDbnbTemplate(strings.Join(lines, "\n"))
|
||||
return getDbnbTemplate(s)
|
||||
}
|
||||
|
||||
return getPyTemplate(s), nil
|
||||
}
|
||||
|
||||
func getPyTemplate(s string) string {
|
||||
return pyTrampolineData
|
||||
return pyTrampolineData, nil
|
||||
}
|
||||
|
||||
func getDbnbTemplate(s string) (string, error) {
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
|
||||
"github.com/databricks/cli/bundle"
|
||||
"github.com/databricks/cli/bundle/config/mutator"
|
||||
jobs_utils "github.com/databricks/cli/libs/jobs"
|
||||
"github.com/databricks/databricks-sdk-go/service/jobs"
|
||||
)
|
||||
|
||||
|
@ -67,8 +68,8 @@ func (t *pythonTrampoline) GetTemplate(b *bundle.Bundle, task *jobs.Task) (strin
|
|||
return NOTEBOOK_TEMPLATE, nil
|
||||
}
|
||||
|
||||
func (t *pythonTrampoline) GetTasks(b *bundle.Bundle) []mutator.TaskWithJobKey {
|
||||
return mutator.GetTasksWithJobKeyBy(b, func(task *jobs.Task) bool {
|
||||
func (t *pythonTrampoline) GetTasks(b *bundle.Bundle) []jobs_utils.TaskWithJobKey {
|
||||
return jobs_utils.GetTasksWithJobKeyBy(b, func(task *jobs.Task) bool {
|
||||
return task.PythonWheelTask != nil
|
||||
})
|
||||
}
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
package jobs_utils
|
||||
|
||||
import (
|
||||
"github.com/databricks/cli/bundle"
|
||||
"github.com/databricks/databricks-sdk-go/service/jobs"
|
||||
)
|
||||
|
||||
type TaskWithJobKey struct {
|
||||
Task *jobs.Task
|
||||
JobKey string
|
||||
}
|
||||
|
||||
func GetTasksWithJobKeyBy(b *bundle.Bundle, filter func(*jobs.Task) bool) []TaskWithJobKey {
|
||||
tasks := make([]TaskWithJobKey, 0)
|
||||
for k := range b.Config.Resources.Jobs {
|
||||
for i := range b.Config.Resources.Jobs[k].Tasks {
|
||||
t := &b.Config.Resources.Jobs[k].Tasks[i]
|
||||
if filter(t) {
|
||||
tasks = append(tasks, TaskWithJobKey{
|
||||
JobKey: k,
|
||||
Task: t,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return tasks
|
||||
}
|
|
@ -0,0 +1,62 @@
|
|||
package jobs_utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/databricks/cli/bundle"
|
||||
"github.com/databricks/cli/bundle/config"
|
||||
"github.com/databricks/cli/bundle/config/resources"
|
||||
"github.com/databricks/databricks-sdk-go/service/jobs"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCorrectlyFilterTasksByFn(t *testing.T) {
|
||||
bundle := &bundle.Bundle{
|
||||
Config: config.Root{
|
||||
Resources: config.Resources{
|
||||
Jobs: map[string]*resources.Job{
|
||||
"job1": {
|
||||
JobSettings: &jobs.JobSettings{
|
||||
Tasks: []jobs.Task{
|
||||
{
|
||||
TaskKey: "job1_key1",
|
||||
PythonWheelTask: &jobs.PythonWheelTask{},
|
||||
},
|
||||
{
|
||||
TaskKey: "job1_key2",
|
||||
NotebookTask: &jobs.NotebookTask{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"job2": {
|
||||
JobSettings: &jobs.JobSettings{
|
||||
Tasks: []jobs.Task{
|
||||
{
|
||||
TaskKey: "job1_key1",
|
||||
PythonWheelTask: &jobs.PythonWheelTask{},
|
||||
},
|
||||
{
|
||||
TaskKey: "job2_key2",
|
||||
NotebookTask: &jobs.NotebookTask{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tasks := GetTasksWithJobKeyBy(bundle, func(task *jobs.Task) bool {
|
||||
return task.PythonWheelTask != nil
|
||||
})
|
||||
|
||||
require.Len(t, tasks, 2)
|
||||
|
||||
require.Equal(t, "job1", tasks[0].JobKey)
|
||||
require.Equal(t, "job1_key1", tasks[0].Task.TaskKey)
|
||||
|
||||
require.Equal(t, "job2", tasks[1].JobKey)
|
||||
require.Equal(t, "job1_key1", tasks[1].Task.TaskKey)
|
||||
}
|
Loading…
Reference in New Issue