diff --git a/bundle/config/mutator/apply_presets.go b/bundle/config/mutator/apply_presets.go index 1fd49206f..8e3a0baf3 100644 --- a/bundle/config/mutator/apply_presets.go +++ b/bundle/config/mutator/apply_presets.go @@ -2,6 +2,7 @@ package mutator import ( "context" + "fmt" "path" "slices" "sort" @@ -9,6 +10,7 @@ import ( "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config" + "github.com/databricks/cli/bundle/config/resources" "github.com/databricks/cli/libs/diag" "github.com/databricks/cli/libs/dyn" "github.com/databricks/cli/libs/textutil" @@ -37,6 +39,9 @@ func (m *applyPresets) Name() string { func (m *applyPresets) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics { var diags diag.Diagnostics + if d := validateCatalogAndSchema(b); d != nil { + return d // fast fail since code below would fail + } if d := validatePauseStatus(b); d != nil { diags = diags.Extend(d) } @@ -46,7 +51,7 @@ func (m *applyPresets) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnos prefix := t.NamePrefix tags := toTagArray(t.Tags) - // Jobs presets: Prefix, Tags, JobsMaxConcurrentRuns, TriggerPauseStatus + // Jobs presets: Prefix, Tags, JobsMaxConcurrentRuns, TriggerPauseStatus, Catalog, Schema for key, j := range r.Jobs { if j.JobSettings == nil { diags = diags.Extend(diag.Errorf("job %s is not defined", key)) @@ -80,9 +85,12 @@ func (m *applyPresets) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnos j.Trigger.PauseStatus = paused } } + if t.Catalog != "" || t.Schema != "" { + diags = diags.Extend(validateJobUsesCatalogAndSchema(b, key, j)) + } } - // Pipelines presets: Prefix, PipelinesDevelopment + // Pipelines presets: Prefix, PipelinesDevelopment, Catalog, Schema for key, p := range r.Pipelines { if p.PipelineSpec == nil { diags = diags.Extend(diag.Errorf("pipeline %s is not defined", key)) @@ -95,7 +103,13 @@ func (m *applyPresets) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnos if t.TriggerPauseStatus == config.Paused { p.Continuous = false } - // As of 2024-06, pipelines don't yet support tags + if t.Catalog != "" && p.Catalog == "" { + p.Catalog = t.Catalog + } + if t.Schema != "" && p.Target == "" { + p.Target = t.Schema + } + // As of 2024-10, pipelines don't yet support tags } // Models presets: Prefix, Tags @@ -155,18 +169,23 @@ func (m *applyPresets) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnos // As of 2024-06, model serving endpoints don't yet support tags } - // Registered models presets: Prefix + // Registered models presets: Prefix, Catalog, Schema for key, m := range r.RegisteredModels { if m.CreateRegisteredModelRequest == nil { diags = diags.Extend(diag.Errorf("registered model %s is not defined", key)) continue } m.Name = normalizePrefix(prefix) + m.Name - + if t.Catalog != "" && m.CatalogName == "" { + m.CatalogName = t.Catalog + } + if t.Schema != "" && m.SchemaName == "" { + m.SchemaName = t.Schema + } // As of 2024-06, registered models don't yet support tags } - // Quality monitors presets: Schedule + // Quality monitors presets: Schedule, Catalog, Schema if t.TriggerPauseStatus == config.Paused { for key, q := range r.QualityMonitors { if q.CreateMonitor == nil { @@ -179,16 +198,30 @@ func (m *applyPresets) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnos if q.Schedule != nil && q.Schedule.PauseStatus != catalog.MonitorCronSchedulePauseStatusUnpaused { q.Schedule = nil } + if t.Catalog != "" && t.Schema != "" { + parts := strings.Split(q.TableName, ".") + if len(parts) != 3 { + q.TableName = fmt.Sprintf("%s.%s.%s", t.Catalog, t.Schema, q.TableName) + } + } } } - // Schemas: Prefix + // Schemas: Prefix, Catalog, Schema for key, s := range r.Schemas { if s.CreateSchema == nil { diags = diags.Extend(diag.Errorf("schema %s is not defined", key)) continue } s.Name = normalizePrefix(prefix) + s.Name + if t.Catalog != "" && s.CatalogName == "" { + s.CatalogName = t.Catalog + } + if t.Schema != "" && s.Name == "" { + // If there is a schema preset such as 'dev', we directly + // use that name and don't add any prefix (which might result in dev_dev). + s.Name = t.Schema + } // HTTP API for schemas doesn't yet support tags. It's only supported in // the Databricks UI and via the SQL API. } @@ -204,10 +237,10 @@ func (m *applyPresets) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnos c.CustomTags = make(map[string]string) } for _, tag := range tags { - normalisedKey := b.Tagging.NormalizeKey(tag.Key) - normalisedValue := b.Tagging.NormalizeValue(tag.Value) - if _, ok := c.CustomTags[normalisedKey]; !ok { - c.CustomTags[normalisedKey] = normalisedValue + normalizedKey := b.Tagging.NormalizeKey(tag.Key) + normalizedValue := b.Tagging.NormalizeValue(tag.Value) + if _, ok := c.CustomTags[normalizedKey]; !ok { + c.CustomTags[normalizedKey] = normalizedValue } } } @@ -227,6 +260,46 @@ func validatePauseStatus(b *bundle.Bundle) diag.Diagnostics { }} } +func validateCatalogAndSchema(b *bundle.Bundle) diag.Diagnostics { + p := b.Config.Presets + if (p.Catalog != "" && p.Schema == "") || (p.Catalog == "" && p.Schema != "") { + return diag.Diagnostics{{ + Summary: "presets.catalog and presets.schema must always be set together", + Severity: diag.Error, + Locations: []dyn.Location{b.Config.GetLocation("presets")}, + }} + } + return nil +} + +func validateJobUsesCatalogAndSchema(b *bundle.Bundle, key string, job *resources.Job) diag.Diagnostics { + if !hasParameter(job.Parameters, "catalog") || !hasParameter(job.Parameters, "schema") { + return diag.Diagnostics{{ + Summary: fmt.Sprintf("job %s must pass catalog and schema presets as parameters as follows:\n"+ + " parameters:\n"+ + " - name: catalog:\n"+ + " default: ${presets.catalog}\n"+ + " - name: schema\n"+ + " default: ${presets.schema}\n", key), + Severity: diag.Error, + Locations: []dyn.Location{b.Config.GetLocation("resources.jobs." + key)}, + }} + } + return nil +} + +func hasParameter(parameters []jobs.JobParameterDefinition, name string) bool { + if parameters == nil { + return false + } + for _, p := range parameters { + if p.Name == name { + return true + } + } + return false +} + // toTagArray converts a map of tags to an array of tags. // We sort tags so ensure stable ordering. func toTagArray(tags map[string]string) []Tag { diff --git a/bundle/config/presets.go b/bundle/config/presets.go index 61009a252..948c9043a 100644 --- a/bundle/config/presets.go +++ b/bundle/config/presets.go @@ -19,6 +19,12 @@ type Presets struct { // Tags to add to all resources. Tags map[string]string `json:"tags,omitempty"` + + // Catalog is the default catalog for all resources. + Catalog string `json:"catalog,omitempty"` + + // Schema is the default schema for all resources. + Schema string `json:"schema,omitempty"` } // IsExplicitlyEnabled tests whether this feature is explicitly enabled. diff --git a/bundle/phases/initialize.go b/bundle/phases/initialize.go index 5582016fd..2e6d7dce9 100644 --- a/bundle/phases/initialize.go +++ b/bundle/phases/initialize.go @@ -61,6 +61,7 @@ func Initialize() bundle.Mutator { "bundle", "workspace", "variables", + "presets", ), // Provide permission config errors & warnings after initializing all variables permissions.PermissionDiagnostics(), diff --git a/libs/template/templates/default-python/template/{{.project_name}}/databricks.yml.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/databricks.yml.tmpl index c42b822a8..13e2e6c9c 100644 --- a/libs/template/templates/default-python/template/{{.project_name}}/databricks.yml.tmpl +++ b/libs/template/templates/default-python/template/{{.project_name}}/databricks.yml.tmpl @@ -16,6 +16,9 @@ targets: default: true workspace: host: {{workspace_host}} + presets: + catalog: {{default_catalog}} + schema: default prod: mode: production @@ -23,6 +26,9 @@ targets: host: {{workspace_host}} # We explicitly specify /Workspace/Users/{{user_name}} to make sure we only have a single copy. root_path: /Workspace/Users/{{user_name}}/.bundle/${bundle.name}/${bundle.target} + presets: + catalog: {{default_catalog}} + schema: default permissions: - {{if is_service_principal}}service_principal{{else}}user{{end}}_name: {{user_name}} level: CAN_MANAGE diff --git a/libs/template/templates/default-python/template/{{.project_name}}/resources/{{.project_name}}.job.yml.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/resources/{{.project_name}}.job.yml.tmpl index 5211e3894..17b209a11 100644 --- a/libs/template/templates/default-python/template/{{.project_name}}/resources/{{.project_name}}.job.yml.tmpl +++ b/libs/template/templates/default-python/template/{{.project_name}}/resources/{{.project_name}}.job.yml.tmpl @@ -10,22 +10,26 @@ resources: {{.project_name}}_job: name: {{.project_name}}_job + {{if or (eq .include_notebook "yes") (eq .include_python "yes") -}} + parameters: + - name: catalog + default: ${presets.catalog} + - name: schema + default: ${presets.schema} + + {{end -}} trigger: # Run this job every day, exactly one day from the last run; see https://docs.databricks.com/api/workspace/jobs/create#trigger periodic: interval: 1 unit: DAYS - {{- if not is_service_principal}} - + {{if not is_service_principal -}} email_notifications: on_failure: - {{user_name}} - {{else}} - {{end -}} - tasks: {{- if eq .include_notebook "yes" }} - task_key: notebook_task diff --git a/libs/template/templates/default-python/template/{{.project_name}}/src/notebook.ipynb.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/src/notebook.ipynb.tmpl index 6782a053b..4163389b3 100644 --- a/libs/template/templates/default-python/template/{{.project_name}}/src/notebook.ipynb.tmpl +++ b/libs/template/templates/default-python/template/{{.project_name}}/src/notebook.ipynb.tmpl @@ -23,10 +23,25 @@ "metadata": {}, "outputs": [], "source": [ + "# Automatically reload this notebook when it is edited\n", "%load_ext autoreload\n", "%autoreload 2" ] }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Set the catalog and schema for the current session\n", + "dbutils.widgets.text('catalog', '{{default_catalog}}')\n", + "dbutils.widgets.text('schema', 'default')\n", + "catalog = dbutils.widgets.get('catalog')\n", + "schema = dbutils.widgets.get('schema')\n", + "spark.sql(f'USE {catalog}.{schema}')" + ] + }, { "cell_type": "code", "execution_count": 0, diff --git a/libs/template/templates/default-python/template/{{.project_name}}/src/{{.project_name}}/main.py.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/src/{{.project_name}}/main.py.tmpl index c514c6dc5..1ac627a0c 100644 --- a/libs/template/templates/default-python/template/{{.project_name}}/src/{{.project_name}}/main.py.tmpl +++ b/libs/template/templates/default-python/template/{{.project_name}}/src/{{.project_name}}/main.py.tmpl @@ -15,7 +15,15 @@ def get_spark() -> SparkSession: return SparkSession.builder.getOrCreate() def main(): - get_taxis(get_spark()).show(5) + # Set the catalog and schema for the current session + parser = argparse.ArgumentParser() + parser.add_argument('--catalog', '-c', required=True) + parser.add_argument('--schema', '-s', required=True) + args, unknown = parser.parse_known_args() + spark = get_spark() + spark.sql(f"USE {args.catalog}.{args.schema}") + + get_taxis(spark).show(5) if __name__ == '__main__': main()