diff --git a/internal/bundle/bundles/spark_jar_task/databricks_template_schema.json b/internal/bundle/bundles/spark_jar_task/databricks_template_schema.json index 078dff97..1381da1d 100644 --- a/internal/bundle/bundles/spark_jar_task/databricks_template_schema.json +++ b/internal/bundle/bundles/spark_jar_task/databricks_template_schema.json @@ -24,6 +24,10 @@ "artifact_path": { "type": "string", "description": "Path to the remote base path for artifacts" + }, + "instance_pool_id": { + "type": "string", + "description": "Instance pool id for job cluster" } } } diff --git a/internal/bundle/bundles/spark_jar_task/template/databricks.yml.tmpl b/internal/bundle/bundles/spark_jar_task/template/databricks.yml.tmpl index 24a6d7d8..8c9331fe 100644 --- a/internal/bundle/bundles/spark_jar_task/template/databricks.yml.tmpl +++ b/internal/bundle/bundles/spark_jar_task/template/databricks.yml.tmpl @@ -22,6 +22,7 @@ resources: num_workers: 1 spark_version: "{{.spark_version}}" node_type_id: "{{.node_type_id}}" + instance_pool_id: "{{.instance_pool_id}}" spark_jar_task: main_class_name: PrintArgs libraries: diff --git a/internal/bundle/spark_jar_test.go b/internal/bundle/spark_jar_test.go index c981e775..98bfa4a9 100644 --- a/internal/bundle/spark_jar_test.go +++ b/internal/bundle/spark_jar_test.go @@ -6,15 +6,14 @@ import ( "github.com/databricks/cli/internal" "github.com/databricks/cli/internal/acc" + "github.com/databricks/cli/libs/env" "github.com/google/uuid" "github.com/stretchr/testify/require" ) func runSparkJarTest(t *testing.T, sparkVersion string) { - t.Skip("Temporarily skipping the test until auth / permission issues for UC volumes are resolved.") - - env := internal.GetEnvOrSkipTest(t, "CLOUD_ENV") - t.Log(env) + cloudEnv := internal.GetEnvOrSkipTest(t, "CLOUD_ENV") + t.Log(cloudEnv) if os.Getenv("TEST_METASTORE_ID") == "" { t.Skip("Skipping tests that require a UC Volume when metastore id is not set.") @@ -24,14 +23,16 @@ func runSparkJarTest(t *testing.T, sparkVersion string) { w := wt.W volumePath := internal.TemporaryUcVolume(t, w) - nodeTypeId := internal.GetNodeTypeId(env) + nodeTypeId := internal.GetNodeTypeId(cloudEnv) tmpDir := t.TempDir() + instancePoolId := env.Get(ctx, "TEST_INSTANCE_POOL_ID") bundleRoot, err := initTestTemplateWithBundleRoot(t, ctx, "spark_jar_task", map[string]any{ - "node_type_id": nodeTypeId, - "unique_id": uuid.New().String(), - "spark_version": sparkVersion, - "root": tmpDir, - "artifact_path": volumePath, + "node_type_id": nodeTypeId, + "unique_id": uuid.New().String(), + "spark_version": sparkVersion, + "root": tmpDir, + "artifact_path": volumePath, + "instance_pool_id": instancePoolId, }, tmpDir) require.NoError(t, err)