Enable environment overrides for job tasks (#779)

## Changes
Follow-up to https://github.com/databricks/cli/pull/658.

When a job definition has multiple tasks with the same task key, it is
considered invalid. Instead, we combine the definitions that share a key
into a single task. This is consistent with how environment overrides are
applied: an override is appended to the tasks slice and then merged into
the original task definition, which gives us a well-defined way to combine
them.
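As a minimal illustration of those merge semantics, here is a self-contained sketch using the same mergo options the implementation relies on (`mergo.WithOverride` and `mergo.WithAppendSlice`); the `Task` and `Cluster` structs below are simplified stand-ins, not the actual SDK types:

```go
package main

import (
	"fmt"

	"github.com/imdario/mergo"
)

// Cluster and Task are simplified stand-ins for the SDK's
// compute.ClusterSpec and jobs.Task types.
type Cluster struct {
	SparkVersion string
	NodeTypeId   string
	NumWorkers   int
}

type Task struct {
	TaskKey    string
	NewCluster *Cluster
}

func main() {
	base := Task{
		TaskKey:    "foo",
		NewCluster: &Cluster{SparkVersion: "13.3.x-scala2.12", NumWorkers: 2},
	}
	override := Task{
		TaskKey:    "foo",
		NewCluster: &Cluster{NodeTypeId: "i3.2xlarge", NumWorkers: 4},
	}

	// Non-zero fields in the override win; fields the override leaves empty
	// keep their value from the base definition.
	if err := mergo.Merge(&base, &override, mergo.WithOverride, mergo.WithAppendSlice); err != nil {
		panic(err)
	}

	fmt.Println(base.NewCluster.SparkVersion, base.NewCluster.NodeTypeId, base.NewCluster.NumWorkers)
	// Output: 13.3.x-scala2.12 i3.2xlarge 4
}
```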

## Tests
Added unit tests
Andrew Nester, 2023-09-18 16:13:50 +02:00 (committed by GitHub)
6 changed files with 179 additions and 0 deletions


@@ -141,3 +141,14 @@ func (r *Resources) MergeJobClusters() error {
	}
	return nil
}

// MergeTasks iterates over all jobs and merges their tasks.
// This is called after applying the target overrides.
func (r *Resources) MergeTasks() error {
	for _, job := range r.Jobs {
		if err := job.MergeTasks(); err != nil {
			return err
		}
	}
	return nil
}


@@ -47,3 +47,36 @@ func (j *Job) MergeJobClusters() error {
	j.JobClusters = output
	return nil
}

// MergeTasks merges tasks with the same key.
// The tasks field is a slice, and as such, overrides are appended to it.
// We can identify a task by its task key, however, so we can use this key
// to figure out which definitions are actually overrides and merge them.
func (j *Job) MergeTasks() error {
	keys := make(map[string]*jobs.Task)
	tasks := make([]jobs.Task, 0, len(j.Tasks))

	// Target overrides are always appended, so we can iterate in natural order to
	// first find the base definition, and merge instances we encounter later.
	for i := range j.Tasks {
		key := j.Tasks[i].TaskKey

		// Register the task with key if not yet seen before.
		ref, ok := keys[key]
		if !ok {
			tasks = append(tasks, j.Tasks[i])
			// Point into the output slice; it was preallocated to full capacity
			// above, so later appends never reallocate it and the pointer
			// stays valid.
			keys[key] = &tasks[len(tasks)-1]
			continue
		}

		// Merge this instance into the reference.
		err := mergo.Merge(ref, &j.Tasks[i], mergo.WithOverride, mergo.WithAppendSlice)
		if err != nil {
			return err
		}
	}

	// Overwrite resulting slice.
	j.Tasks = tasks
	return nil
}
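Note that `mergo.WithAppendSlice` means slice fields within a task (for example, a task's libraries) are concatenated across the base definition and its overrides rather than replaced wholesale.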


@@ -55,3 +55,50 @@ func TestJobMergeJobClusters(t *testing.T) {
	jc1 := j.JobClusters[1].NewCluster
	assert.Equal(t, "10.4.x-scala2.12", jc1.SparkVersion)
}

func TestJobMergeTasks(t *testing.T) {
	j := &Job{
		JobSettings: &jobs.JobSettings{
			Tasks: []jobs.Task{
				{
					TaskKey: "foo",
					NewCluster: &compute.ClusterSpec{
						SparkVersion: "13.3.x-scala2.12",
						NodeTypeId:   "i3.xlarge",
						NumWorkers:   2,
					},
				},
				{
					TaskKey: "bar",
					NewCluster: &compute.ClusterSpec{
						SparkVersion: "10.4.x-scala2.12",
					},
				},
				{
					TaskKey: "foo",
					NewCluster: &compute.ClusterSpec{
						NodeTypeId: "i3.2xlarge",
						NumWorkers: 4,
					},
				},
			},
		},
	}

	err := j.MergeTasks()
	require.NoError(t, err)
	assert.Len(t, j.Tasks, 2)
	assert.Equal(t, "foo", j.Tasks[0].TaskKey)
	assert.Equal(t, "bar", j.Tasks[1].TaskKey)

	// This task was merged with a subsequent one.
	task0 := j.Tasks[0].NewCluster
	assert.Equal(t, "13.3.x-scala2.12", task0.SparkVersion)
	assert.Equal(t, "i3.2xlarge", task0.NodeTypeId)
	assert.Equal(t, 4, task0.NumWorkers)

	// This task was left untouched.
	task1 := j.Tasks[1].NewCluster
	assert.Equal(t, "10.4.x-scala2.12", task1.SparkVersion)
}


@@ -242,6 +242,11 @@ func (r *Root) MergeTargetOverrides(target *Target) error {
		if err != nil {
			return err
		}

		err = r.Resources.MergeTasks()
		if err != nil {
			return err
		}
	}

	if target.Variables != nil {


@@ -0,0 +1,44 @@
bundle:
  name: override_job_tasks

workspace:
  host: https://acme.cloud.databricks.com/

resources:
  jobs:
    foo:
      name: job
      tasks:
        - task_key: key1
          new_cluster:
            spark_version: 13.3.x-scala2.12
          spark_python_task:
            python_file: ./test1.py
        - task_key: key2
          new_cluster:
            spark_version: 13.3.x-scala2.12
          spark_python_task:
            python_file: ./test2.py

targets:
  development:
    resources:
      jobs:
        foo:
          tasks:
            - task_key: key1
              new_cluster:
                node_type_id: i3.xlarge
                num_workers: 1
  staging:
    resources:
      jobs:
        foo:
          tasks:
            - task_key: key2
              new_cluster:
                node_type_id: i3.2xlarge
                num_workers: 4
              spark_python_task:
                python_file: ./test3.py
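For the `development` target, the merged result implied by this fixture looks roughly as follows; this is a sketch derived from the assertions in the test below, not output generated by the tool:

```yaml
# Effective tasks for target "development" (sketch):
tasks:
  - task_key: key1
    new_cluster:
      spark_version: 13.3.x-scala2.12  # kept from the base definition
      node_type_id: i3.xlarge          # from the development override
      num_workers: 1                   # from the development override
    spark_python_task:
      python_file: ./test1.py
  - task_key: key2                     # not overridden by the development target
    new_cluster:
      spark_version: 13.3.x-scala2.12
    spark_python_task:
      python_file: ./test2.py
```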


@@ -0,0 +1,39 @@
package config_tests

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestOverrideTasksDev(t *testing.T) {
	b := loadTarget(t, "./override_job_tasks", "development")
	assert.Equal(t, "job", b.Config.Resources.Jobs["foo"].Name)
	assert.Len(t, b.Config.Resources.Jobs["foo"].Tasks, 2)

	tasks := b.Config.Resources.Jobs["foo"].Tasks
	assert.Equal(t, tasks[0].TaskKey, "key1")
	assert.Equal(t, tasks[0].NewCluster.NodeTypeId, "i3.xlarge")
	assert.Equal(t, tasks[0].NewCluster.NumWorkers, 1)
	assert.Equal(t, tasks[0].SparkPythonTask.PythonFile, "./test1.py")

	assert.Equal(t, tasks[1].TaskKey, "key2")
	assert.Equal(t, tasks[1].NewCluster.SparkVersion, "13.3.x-scala2.12")
	assert.Equal(t, tasks[1].SparkPythonTask.PythonFile, "./test2.py")
}

func TestOverrideTasksStaging(t *testing.T) {
	b := loadTarget(t, "./override_job_tasks", "staging")
	assert.Equal(t, "job", b.Config.Resources.Jobs["foo"].Name)
	assert.Len(t, b.Config.Resources.Jobs["foo"].Tasks, 2)

	tasks := b.Config.Resources.Jobs["foo"].Tasks
	assert.Equal(t, tasks[0].TaskKey, "key1")
	assert.Equal(t, tasks[0].NewCluster.SparkVersion, "13.3.x-scala2.12")
	assert.Equal(t, tasks[0].SparkPythonTask.PythonFile, "./test1.py")

	assert.Equal(t, tasks[1].TaskKey, "key2")
	assert.Equal(t, tasks[1].NewCluster.NodeTypeId, "i3.2xlarge")
	assert.Equal(t, tasks[1].NewCluster.NumWorkers, 4)
	assert.Equal(t, tasks[1].SparkPythonTask.PythonFile, "./test3.py")
}