diff --git a/cmd/bundle/init.go b/cmd/bundle/init.go index 687c141ec..4da5a69be 100644 --- a/cmd/bundle/init.go +++ b/cmd/bundle/init.go @@ -4,152 +4,17 @@ import ( "context" "errors" "fmt" - "io/fs" - "os" "path/filepath" - "slices" "strings" "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/dbr" "github.com/databricks/cli/libs/filer" - "github.com/databricks/cli/libs/git" "github.com/databricks/cli/libs/template" "github.com/spf13/cobra" ) -var gitUrlPrefixes = []string{ - "https://", - "git@", -} - -type nativeTemplate struct { - name string - gitUrl string - description string - aliases []string - hidden bool -} - -const customTemplate = "custom..." - -var nativeTemplates = []nativeTemplate{ - { - name: "default-python", - description: "The default Python template for Notebooks / Delta Live Tables / Workflows", - }, - { - name: "default-sql", - description: "The default SQL template for .sql files that run with Databricks SQL", - }, - { - name: "dbt-sql", - description: "The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks)", - }, - { - name: "mlops-stacks", - gitUrl: "https://github.com/databricks/mlops-stacks", - description: "The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)", - aliases: []string{"mlops-stack"}, - }, - { - name: "default-pydabs", - gitUrl: "https://databricks.github.io/workflows-authoring-toolkit/pydabs-template.git", - hidden: true, - description: "The default PyDABs template", - }, - { - name: customTemplate, - description: "Bring your own template", - }, -} - -// Return template descriptions for command-line help -func nativeTemplateHelpDescriptions() string { - var lines []string - for _, template := range nativeTemplates { - if template.name != customTemplate && !template.hidden { - lines = append(lines, fmt.Sprintf("- %s: %s", template.name, template.description)) - } - } - return strings.Join(lines, "\n") -} - -// Return template options for an interactive prompt -func nativeTemplateOptions() []cmdio.Tuple { - names := make([]cmdio.Tuple, 0, len(nativeTemplates)) - for _, template := range nativeTemplates { - if template.hidden { - continue - } - tuple := cmdio.Tuple{ - Name: template.name, - Id: template.description, - } - names = append(names, tuple) - } - return names -} - -func getNativeTemplateByDescription(description string) string { - for _, template := range nativeTemplates { - if template.description == description { - return template.name - } - } - return "" -} - -func getUrlForNativeTemplate(name string) string { - for _, template := range nativeTemplates { - if template.name == name { - return template.gitUrl - } - if slices.Contains(template.aliases, name) { - return template.gitUrl - } - } - return "" -} - -func getFsForNativeTemplate(name string) (fs.FS, error) { - builtin, err := template.Builtin() - if err != nil { - return nil, err - } - - // If this is a built-in template, the return value will be non-nil. - var templateFS fs.FS - for _, entry := range builtin { - if entry.Name == name { - templateFS = entry.FS - break - } - } - - return templateFS, nil -} - -func isRepoUrl(url string) bool { - result := false - for _, prefix := range gitUrlPrefixes { - if strings.HasPrefix(url, prefix) { - result = true - break - } - } - return result -} - -// Computes the repo name from the repo URL. Treats the last non empty word -// when splitting at '/' as the repo name. For example: for url git@github.com:databricks/cli.git -// the name would be "cli.git" -func repoName(url string) string { - parts := strings.Split(strings.TrimRight(url, "/"), "/") - return parts[len(parts)-1] -} - func constructOutputFiler(ctx context.Context, outputDir string) (filer.Filer, error) { outputDir, err := filepath.Abs(outputDir) if err != nil { @@ -182,7 +47,7 @@ TEMPLATE_PATH optionally specifies which template to use. It can be one of the f - a local file system path with a template directory - a Git repository URL, e.g. https://github.com/my/repository -See https://docs.databricks.com/en/dev-tools/bundles/templates.html for more information on templates.`, nativeTemplateHelpDescriptions()), +See https://docs.databricks.com/en/dev-tools/bundles/templates.html for more information on templates.`, template.HelpDescriptions()), } var configFile string @@ -196,7 +61,6 @@ See https://docs.databricks.com/en/dev-tools/bundles/templates.html for more inf cmd.Flags().StringVar(&branch, "tag", "", "Git tag to use for template initialization") cmd.Flags().StringVar(&tag, "branch", "", "Git branch to use for template initialization") - cmd.PreRunE = root.MustWorkspaceClient cmd.RunE = func(cmd *cobra.Command, args []string) error { if tag != "" && branch != "" { return errors.New("only one of --tag or --branch can be specified") @@ -208,82 +72,51 @@ See https://docs.databricks.com/en/dev-tools/bundles/templates.html for more inf ref = tag } + var tmpl *template.Template + var err error ctx := cmd.Context() - var templatePath string + if len(args) > 0 { - templatePath = args[0] - } else { - var err error - if !cmdio.IsPromptSupported(ctx) { - return errors.New("please specify a template") + // User already specified a template local path or a Git URL. Use that + // information to configure a reader for the template + tmpl = template.Get(template.Custom) + // TODO: Get rid of the name arg. + if template.IsGitRepoUrl(args[0]) { + tmpl.SetReader(template.NewGitReader("", args[0], ref, templateDir)) + } else { + tmpl.SetReader(template.NewLocalReader("", args[0])) + } + } else { + tmplId, err := template.PromptForTemplateId(cmd.Context(), ref, templateDir) + if tmplId == template.Custom { + // If a user selects custom during the prompt, ask them to provide a path or Git URL + // as a positional argument. + cmdio.LogString(ctx, "Please specify a path or Git repository to use a custom template.") + cmdio.LogString(ctx, "See https://docs.databricks.com/en/dev-tools/bundles/templates.html to learn more about custom templates.") + return nil } - description, err := cmdio.SelectOrdered(ctx, nativeTemplateOptions(), "Template to use") if err != nil { return err } - templatePath = getNativeTemplateByDescription(description) + + tmpl = template.Get(tmplId) } + defer tmpl.Reader.Close() + outputFiler, err := constructOutputFiler(ctx, outputDir) if err != nil { return err } - if templatePath == customTemplate { - cmdio.LogString(ctx, "Please specify a path or Git repository to use a custom template.") - cmdio.LogString(ctx, "See https://docs.databricks.com/en/dev-tools/bundles/templates.html to learn more about custom templates.") - return nil - } + tmpl.Writer.Initialize(tmpl.Reader, configFile, outputFiler) - // Expand templatePath to a git URL if it's an alias for a known native template - // and we know it's git URL. - if gitUrl := getUrlForNativeTemplate(templatePath); gitUrl != "" { - templatePath = gitUrl - } - - if !isRepoUrl(templatePath) { - if templateDir != "" { - return errors.New("--template-dir can only be used with a Git repository URL") - } - - templateFS, err := getFsForNativeTemplate(templatePath) - if err != nil { - return err - } - - // If this is not a built-in template, then it must be a local file system path. - if templateFS == nil { - templateFS = os.DirFS(templatePath) - } - - // skip downloading the repo because input arg is not a URL. We assume - // it's a path on the local file system in that case - return template.Materialize(ctx, configFile, templateFS, outputFiler) - } - - // Create a temporary directory with the name of the repository. The '*' - // character is replaced by a random string in the generated temporary directory. - repoDir, err := os.MkdirTemp("", repoName(templatePath)+"-*") + err = tmpl.Writer.Materialize(ctx) if err != nil { return err } - // start the spinner - promptSpinner := cmdio.Spinner(ctx) - promptSpinner <- "Downloading the template\n" - - // TODO: Add automated test that the downloaded git repo is cleaned up. - // Clone the repository in the temporary directory - err = git.Clone(ctx, templatePath, ref, repoDir) - close(promptSpinner) - if err != nil { - return err - } - - // Clean up downloaded repository once the template is materialized. - defer os.RemoveAll(repoDir) - templateFS := os.DirFS(filepath.Join(repoDir, templateDir)) - return template.Materialize(ctx, configFile, templateFS, outputFiler) + return tmpl.Writer.LogTelemetry(ctx) } return cmd } diff --git a/integration/bundle/helpers_test.go b/integration/bundle/helpers_test.go index e884cd8c6..60177297e 100644 --- a/integration/bundle/helpers_test.go +++ b/integration/bundle/helpers_test.go @@ -8,18 +8,13 @@ import ( "os" "os/exec" "path/filepath" - "strings" "github.com/databricks/cli/bundle" - "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/internal/testcli" "github.com/databricks/cli/internal/testutil" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/env" - "github.com/databricks/cli/libs/filer" - "github.com/databricks/cli/libs/flags" "github.com/databricks/cli/libs/folders" - "github.com/databricks/cli/libs/template" "github.com/databricks/databricks-sdk-go" "github.com/stretchr/testify/require" ) @@ -32,19 +27,32 @@ func initTestTemplate(t testutil.TestingT, ctx context.Context, templateName str } func initTestTemplateWithBundleRoot(t testutil.TestingT, ctx context.Context, templateName string, config map[string]any, bundleRoot string) string { - templateRoot := filepath.Join("bundles", templateName) + return "" - configFilePath := writeConfigFile(t, config) + // TODO: Make this function work but do not log telemetry. - ctx = root.SetWorkspaceClient(ctx, nil) - cmd := cmdio.NewIO(ctx, flags.OutputJSON, strings.NewReader(""), os.Stdout, os.Stderr, "", "bundles") - ctx = cmdio.InContext(ctx, cmd) + // templateRoot := filepath.Join("bundles", templateName) - out, err := filer.NewLocalClient(bundleRoot) - require.NoError(t, err) - err = template.Materialize(ctx, configFilePath, os.DirFS(templateRoot), out) - require.NoError(t, err) - return bundleRoot + // configFilePath := writeConfigFile(t, config) + + // ctx = root.SetWorkspaceClient(ctx, nil) + // cmd := cmdio.NewIO(ctx, flags.OutputJSON, strings.NewReader(""), os.Stdout, os.Stderr, "", "bundles") + // ctx = cmdio.InContext(ctx, cmd) + // ctx = telemetry.WithMockLogger(ctx) + + // out, err := filer.NewLocalClient(bundleRoot) + // require.NoError(t, err) + // tmpl := template.TemplateX{ + // TemplateOpts: template.TemplateOpts{ + // ConfigFilePath: configFilePath, + // TemplateFS: os.DirFS(templateRoot), + // OutputFiler: out, + // }, + // } + + // err = tmpl.Materialize(ctx) + // require.NoError(t, err) + // return bundleRoot } func writeConfigFile(t testutil.TestingT, config map[string]any) string { diff --git a/integration/bundle/init_test.go b/integration/bundle/init_test.go index f5c263ca3..3826f5543 100644 --- a/integration/bundle/init_test.go +++ b/integration/bundle/init_test.go @@ -15,6 +15,7 @@ import ( "github.com/databricks/cli/internal/testcli" "github.com/databricks/cli/internal/testutil" "github.com/databricks/cli/libs/iamutil" + "github.com/databricks/cli/libs/telemetry" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -42,6 +43,9 @@ func TestBundleInitOnMlopsStacks(t *testing.T) { ctx, wt := acc.WorkspaceTest(t) w := wt.W + // Use mock logger to introspect the telemetry payload. + ctx = telemetry.WithMockLogger(ctx) + tmpDir1 := t.TempDir() tmpDir2 := t.TempDir() @@ -64,6 +68,28 @@ func TestBundleInitOnMlopsStacks(t *testing.T) { assert.NoFileExists(t, filepath.Join(tmpDir2, "repo_name", projectName, "README.md")) testcli.RequireSuccessfulRun(t, ctx, "bundle", "init", "mlops-stacks", "--output-dir", tmpDir2, "--config-file", filepath.Join(tmpDir1, "config.json")) + // Assert the telemetry payload is correctly logged. + tlmyEvents := telemetry.Introspect(ctx) + require.Len(t, telemetry.Introspect(ctx), 1) + event := tlmyEvents[0].BundleInitEvent + assert.Equal(t, "mlops-stacks", event.TemplateName) + + get := func(key string) string { + for _, v := range event.TemplateEnumArgs { + if v.Key == key { + return v.Value + } + } + return "" + } + + // Enum values should be present in the telemetry payload. + assert.Equal(t, "no", get("input_include_models_in_unity_catalog")) + assert.Equal(t, strings.ToLower(env), get("input_cloud")) + // Freeform strings should not be present in the telemetry payload. + assert.Equal(t, "", get("input_project_name")) + assert.Equal(t, "", get("input_root_dir")) + // Assert that the README.md file was created contents := testutil.ReadFile(t, filepath.Join(tmpDir2, "repo_name", projectName, "README.md")) assert.Contains(t, contents, fmt.Sprintf("# %s", projectName)) @@ -99,6 +125,156 @@ func TestBundleInitOnMlopsStacks(t *testing.T) { assert.Contains(t, job.Settings.Name, fmt.Sprintf("dev-%s-batch-inference-job", projectName)) } +func TestBundleInitTelemetryForDefaultTemplates(t *testing.T) { + projectName := testutil.RandomName("name_") + + tcases := []struct { + name string + args map[string]string + expectedArgs map[string]string + }{ + { + name: "dbt-sql", + args: map[string]string{ + "project_name": fmt.Sprintf("dbt-sql-%s", projectName), + "http_path": "/sql/1.0/warehouses/id", + "default_catalog": "abcd", + "personal_schemas": "yes, use a schema based on the current user name during development", + }, + expectedArgs: map[string]string{ + "personal_schemas": "yes, use a schema based on the current user name during development", + }, + }, + { + name: "default-python", + args: map[string]string{ + "project_name": fmt.Sprintf("default_python_%s", projectName), + "include_notebook": "yes", + "include_dlt": "yes", + "include_python": "no", + }, + expectedArgs: map[string]string{ + "include_notebook": "yes", + "include_dlt": "yes", + "include_python": "no", + }, + }, + { + name: "default-sql", + args: map[string]string{ + "project_name": fmt.Sprintf("sql_project_%s", projectName), + "http_path": "/sql/1.0/warehouses/id", + "default_catalog": "abcd", + "personal_schemas": "yes, automatically use a schema based on the current user name during development", + }, + expectedArgs: map[string]string{ + "personal_schemas": "yes, automatically use a schema based on the current user name during development", + }, + }, + } + + for _, tc := range tcases { + ctx, _ := acc.WorkspaceTest(t) + + // Use mock logger to introspect the telemetry payload. + ctx = telemetry.WithMockLogger(ctx) + + tmpDir1 := t.TempDir() + tmpDir2 := t.TempDir() + + // Create a config file with the project name and root dir + initConfig := tc.args + b, err := json.Marshal(initConfig) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(tmpDir1, "config.json"), b, 0o644) + require.NoError(t, err) + + // Run bundle init + assert.NoDirExists(t, filepath.Join(tmpDir2, tc.args["project_name"])) + testcli.RequireSuccessfulRun(t, ctx, "bundle", "init", tc.name, "--output-dir", tmpDir2, "--config-file", filepath.Join(tmpDir1, "config.json")) + assert.DirExists(t, filepath.Join(tmpDir2, tc.args["project_name"])) + + // Assert the telemetry payload is correctly logged. + logs := telemetry.Introspect(ctx) + require.Len(t, logs, 1) + event := logs[0].BundleInitEvent + assert.Equal(t, event.TemplateName, tc.name) + + get := func(key string) string { + for _, v := range event.TemplateEnumArgs { + if v.Key == key { + return v.Value + } + } + return "" + } + + // Assert the template enum args are correctly logged. + assert.Len(t, event.TemplateEnumArgs, len(tc.expectedArgs)) + for k, v := range tc.expectedArgs { + assert.Equal(t, get(k), v) + } + } +} + +func TestBundleInitTelemetryForCustomTemplates(t *testing.T) { + ctx, _ := acc.WorkspaceTest(t) + + tmpDir1 := t.TempDir() + tmpDir2 := t.TempDir() + tmpDir3 := t.TempDir() + + err := os.Mkdir(filepath.Join(tmpDir1, "template"), 0o755) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(tmpDir1, "template", "foo.txt.tmpl"), []byte("{{bundle_uuid}}"), 0o644) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(tmpDir1, "databricks_template_schema.json"), []byte(` +{ + "properties": { + "a": { + "description": "whatever", + "type": "string" + }, + "b": { + "description": "whatever", + "type": "string", + "enum": ["yes", "no"] + } + } +} +`), 0o644) + require.NoError(t, err) + + // Create a config file with the project name and root dir + initConfig := map[string]string{ + "a": "v1", + "b": "yes", + } + b, err := json.Marshal(initConfig) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(tmpDir3, "config.json"), b, 0o644) + require.NoError(t, err) + + // Use mock logger to introspect the telemetry payload. + ctx = telemetry.WithMockLogger(ctx) + + // Run bundle init. + testcli.RequireSuccessfulRun(t, ctx, "bundle", "init", tmpDir1, "--output-dir", tmpDir2, "--config-file", filepath.Join(tmpDir3, "config.json")) + + // Assert the telemetry payload is correctly logged. For custom templates we should + // never set template_enum_args. + tlmyEvents := telemetry.Introspect(ctx) + require.Len(t, len(tlmyEvents), 1) + event := tlmyEvents[0].BundleInitEvent + assert.Equal(t, "custom", event.TemplateName) + assert.Empty(t, event.TemplateEnumArgs) + + // Ensure that the UUID returned by the `bundle_uuid` helper is the same UUID + // that's logged in the telemetry event. + fileC := testutil.ReadFile(t, filepath.Join(tmpDir2, "foo.txt")) + assert.Equal(t, event.Uuid, fileC) +} + func TestBundleInitHelpers(t *testing.T) { ctx, wt := acc.WorkspaceTest(t) w := wt.W diff --git a/libs/template/builtin.go b/libs/template/builtin.go index dcb3a8858..96cdcbb96 100644 --- a/libs/template/builtin.go +++ b/libs/template/builtin.go @@ -15,6 +15,7 @@ type BuiltinTemplate struct { } // Builtin returns the list of all built-in templates. +// TODO: Make private? func Builtin() ([]BuiltinTemplate, error) { templates, err := fs.Sub(builtinTemplates, "templates") if err != nil { diff --git a/libs/template/config.go b/libs/template/config.go index 8e7695b91..34eee065c 100644 --- a/libs/template/config.go +++ b/libs/template/config.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io/fs" + "slices" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/jsonschema" @@ -273,3 +274,23 @@ func (c *config) validate() error { } return nil } + +// Return enum values selected by the user during template initialization. These +// values are safe to send over in telemetry events due to their limited cardinality. +func (c *config) enumValues() map[string]string { + res := map[string]string{} + for k, p := range c.schema.Properties { + if p.Type != jsonschema.StringType { + continue + } + if p.Enum == nil { + continue + } + v := c.values[k] + + if slices.Contains(p.Enum, v) { + res[k] = v.(string) + } + } + return res +} diff --git a/libs/template/config_test.go b/libs/template/config_test.go index 515a0b9f5..3f971a862 100644 --- a/libs/template/config_test.go +++ b/libs/template/config_test.go @@ -564,3 +564,42 @@ func TestPromptIsSkippedAnyOf(t *testing.T) { assert.True(t, skip) assert.Equal(t, "hello-world", c.values["xyz"]) } + +func TestConfigEnumValues(t *testing.T) { + c := &config{ + schema: &jsonschema.Schema{ + Properties: map[string]*jsonschema.Schema{ + "a": { + Type: jsonschema.StringType, + }, + "b": { + Type: jsonschema.BooleanType, + }, + "c": { + Type: jsonschema.StringType, + Enum: []any{"v1", "v2"}, + }, + "d": { + Type: jsonschema.StringType, + Enum: []any{"v3", "v4"}, + }, + "e": { + Type: jsonschema.StringType, + Enum: []any{"v5", "v6"}, + }, + }, + }, + values: map[string]any{ + "a": "w1", + "b": false, + "c": "v1", + "d": "v3", + "e": "v7", + }, + } + + assert.Equal(t, map[string]string{ + "c": "v1", + "d": "v3", + }, c.enumValues()) +} diff --git a/libs/template/materialize.go b/libs/template/materialize.go deleted file mode 100644 index 86a6a8c37..000000000 --- a/libs/template/materialize.go +++ /dev/null @@ -1,94 +0,0 @@ -package template - -import ( - "context" - "errors" - "fmt" - "io/fs" - - "github.com/databricks/cli/libs/cmdio" - "github.com/databricks/cli/libs/filer" -) - -const ( - libraryDirName = "library" - templateDirName = "template" - schemaFileName = "databricks_template_schema.json" -) - -// This function materializes the input templates as a project, using user defined -// configurations. -// Parameters: -// -// ctx: context containing a cmdio object. This is used to prompt the user -// configFilePath: file path containing user defined config values -// templateFS: root of the template definition -// outputFiler: filer to use for writing the initialized template -func Materialize(ctx context.Context, configFilePath string, templateFS fs.FS, outputFiler filer.Filer) error { - if _, err := fs.Stat(templateFS, schemaFileName); errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("not a bundle template: expected to find a template schema file at %s", schemaFileName) - } - - config, err := newConfig(ctx, templateFS, schemaFileName) - if err != nil { - return err - } - - // Read and assign config values from file - if configFilePath != "" { - err = config.assignValuesFromFile(configFilePath) - if err != nil { - return err - } - } - - helpers := loadHelpers(ctx) - r, err := newRenderer(ctx, config.values, helpers, templateFS, templateDirName, libraryDirName) - if err != nil { - return err - } - - // Print welcome message - welcome := config.schema.WelcomeMessage - if welcome != "" { - welcome, err = r.executeTemplate(welcome) - if err != nil { - return err - } - cmdio.LogString(ctx, welcome) - } - - // Prompt user for any missing config values. Assign default values if - // terminal is not TTY - err = config.promptOrAssignDefaultValues(r) - if err != nil { - return err - } - err = config.validate() - if err != nil { - return err - } - - // Walk and render the template, since input configuration is complete - err = r.walk() - if err != nil { - return err - } - - err = r.persistToDisk(ctx, outputFiler) - if err != nil { - return err - } - - success := config.schema.SuccessMessage - if success == "" { - cmdio.LogString(ctx, "✨ Successfully initialized template") - } else { - success, err = r.executeTemplate(success) - if err != nil { - return err - } - cmdio.LogString(ctx, success) - } - return nil -} diff --git a/libs/template/materialize_test.go b/libs/template/materialize_test.go deleted file mode 100644 index f7cd916e3..000000000 --- a/libs/template/materialize_test.go +++ /dev/null @@ -1,24 +0,0 @@ -package template - -import ( - "context" - "fmt" - "os" - "testing" - - "github.com/databricks/cli/cmd/root" - "github.com/databricks/databricks-sdk-go" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestMaterializeForNonTemplateDirectory(t *testing.T) { - tmpDir := t.TempDir() - w, err := databricks.NewWorkspaceClient(&databricks.Config{}) - require.NoError(t, err) - ctx := root.SetWorkspaceClient(context.Background(), w) - - // Try to materialize a non-template directory. - err = Materialize(ctx, "", os.DirFS(tmpDir), nil) - assert.EqualError(t, err, fmt.Sprintf("not a bundle template: expected to find a template schema file at %s", schemaFileName)) -} diff --git a/libs/template/reader.go b/libs/template/reader.go new file mode 100644 index 000000000..6cfaf9cb6 --- /dev/null +++ b/libs/template/reader.go @@ -0,0 +1,199 @@ +package template + +import ( + "context" + "fmt" + "io/fs" + "os" + "path/filepath" + "strings" + + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/git" +) + +// TODO: Add tests for all these readers. +type Reader interface { + // FS returns a file system that contains the template + // definition files. This function is NOT thread safe. + FS(ctx context.Context) (fs.FS, error) + + // Close releases any resources associated with the reader + // like cleaning up temporary directories. + Close() error + + Name() string +} + +type builtinReader struct { + name string + fsCached fs.FS +} + +func (r *builtinReader) FS(ctx context.Context) (fs.FS, error) { + // If the FS has already been loaded, return it. + if r.fsCached != nil { + return r.fsCached, nil + } + + builtin, err := Builtin() + if err != nil { + return nil, err + } + + var templateFS fs.FS + for _, entry := range builtin { + if entry.Name == r.name { + templateFS = entry.FS + break + } + } + + r.fsCached = templateFS + return r.fsCached, nil +} + +func (r *builtinReader) Close() error { + return nil +} + +func (r *builtinReader) Name() string { + return r.name +} + +type gitReader struct { + name string + // URL of the git repository that contains the template + gitUrl string + // tag or branch to checkout + ref string + // subdirectory within the repository that contains the template + templateDir string + // temporary directory where the repository is cloned + tmpRepoDir string + + fsCached fs.FS +} + +// Computes the repo name from the repo URL. Treats the last non empty word +// when splitting at '/' as the repo name. For example: for url git@github.com:databricks/cli.git +// the name would be "cli.git" +func repoName(url string) string { + parts := strings.Split(strings.TrimRight(url, "/"), "/") + return parts[len(parts)-1] +} + +var gitUrlPrefixes = []string{ + "https://", + "git@", +} + +// TODO: Copy over tests for this function. +func IsGitRepoUrl(url string) bool { + result := false + for _, prefix := range gitUrlPrefixes { + if strings.HasPrefix(url, prefix) { + result = true + break + } + } + return result +} + +// TODO: Can I remove the name from here and other readers? +func NewGitReader(name, gitUrl, ref, templateDir string) Reader { + return &gitReader{ + name: name, + gitUrl: gitUrl, + ref: ref, + templateDir: templateDir, + } +} + +// TODO: Test the idempotency of this function as well. +func (r *gitReader) FS(ctx context.Context) (fs.FS, error) { + // If the FS has already been loaded, return it. + if r.fsCached != nil { + return r.fsCached, nil + } + + // Create a temporary directory with the name of the repository. The '*' + // character is replaced by a random string in the generated temporary directory. + repoDir, err := os.MkdirTemp("", repoName(r.gitUrl)+"-*") + if err != nil { + return nil, err + } + r.tmpRepoDir = repoDir + + // start the spinner + promptSpinner := cmdio.Spinner(ctx) + promptSpinner <- "Downloading the template\n" + + err = git.Clone(ctx, r.gitUrl, r.ref, repoDir) + close(promptSpinner) + if err != nil { + return nil, err + } + + r.fsCached = os.DirFS(filepath.Join(repoDir, r.templateDir)) + return r.fsCached, nil +} + +func (r *gitReader) Close() error { + if r.tmpRepoDir == "" { + return nil + } + + return os.RemoveAll(r.tmpRepoDir) +} + +func (r *gitReader) Name() string { + return r.name +} + +type localReader struct { + name string + // Path on the local filesystem that contains the template + path string + + fsCached fs.FS +} + +func NewLocalReader(name, path string) Reader { + return &localReader{ + name: name, + path: path, + } +} + +func (r *localReader) FS(ctx context.Context) (fs.FS, error) { + // If the FS has already been loaded, return it. + if r.fsCached != nil { + return r.fsCached, nil + } + + r.fsCached = os.DirFS(r.path) + return r.fsCached, nil +} + +func (r *localReader) Close() error { + return nil +} + +func (r *localReader) Name() string { + return r.name +} + +type failReader struct{} + +func (r *failReader) FS(ctx context.Context) (fs.FS, error) { + return nil, fmt.Errorf("this is a placeholder reader that always fails. Please configure a real reader.") +} + +func (r *failReader) Close() error { + return fmt.Errorf("this is a placeholder reader that always fails. Please configure a real reader.") +} + +func (r *failReader) Name() string { + return "failReader" +} diff --git a/libs/template/template.go b/libs/template/template.go new file mode 100644 index 000000000..1467ff2e5 --- /dev/null +++ b/libs/template/template.go @@ -0,0 +1,145 @@ +package template + +import ( + "context" + "fmt" + "strings" + + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/filer" +) + +type Template struct { + Reader Reader + Writer Writer + + Id TemplateId + Description string + Aliases []string + Hidden bool +} + +// TODO: Make details private? +// TODO: Combine this with the generic template struct? +type NativeTemplate struct { + Name string + Description string + Aliases []string + GitUrl string + Hidden bool + IsOwnedByDatabricks bool +} + +type TemplateId string + +const ( + DefaultPython TemplateId = "default-python" + DefaultSql TemplateId = "default-sql" + DbtSql TemplateId = "dbt-sql" + MlopsStacks TemplateId = "mlops-stacks" + DefaultPydabs TemplateId = "default-pydabs" + Custom TemplateId = "custom" +) + +var allTemplates = []Template{ + { + Id: DefaultPython, + Description: "The default Python template for Notebooks / Delta Live Tables / Workflows", + Reader: &builtinReader{name: "default-python"}, + Writer: &writerWithTelemetry{}, + }, + { + Id: DefaultSql, + Description: "The default SQL template for .sql files that run with Databricks SQL", + Reader: &builtinReader{name: "default-sql"}, + Writer: &writerWithTelemetry{}, + }, + { + Id: DbtSql, + Description: "The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks)", + Reader: &builtinReader{name: "dbt-sql"}, + Writer: &writerWithTelemetry{}, + }, + { + Id: MlopsStacks, + Description: "The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)", + Aliases: []string{"mlops-stack"}, + Reader: &gitReader{gitUrl: "https://github.com/databricks/mlops-stacks"}, + Writer: &writerWithTelemetry{}, + }, + { + Id: DefaultPydabs, + Hidden: true, + Description: "The default PyDABs template", + Reader: &gitReader{gitUrl: "https://databricks.github.io/workflows-authoring-toolkit/pydabs-template.git"}, + Writer: &writerWithTelemetry{}, + }, + { + Id: Custom, + Description: "Bring your own template", + Reader: &failReader{}, + Writer: &defaultWriter{}, + }, +} + +func HelpDescriptions() string { + var lines []string + for _, template := range allTemplates { + if template.Id != Custom && !template.Hidden { + lines = append(lines, fmt.Sprintf("- %s: %s", template.Id, template.Description)) + } + } + return strings.Join(lines, "\n") +} + +func options() []cmdio.Tuple { + names := make([]cmdio.Tuple, 0, len(allTemplates)) + for _, template := range allTemplates { + if template.Hidden { + continue + } + tuple := cmdio.Tuple{ + Name: string(template.Id), + Id: template.Description, + } + names = append(names, tuple) + } + return names +} + +// TODO CONTINUE defining the methods that the init command will finally rely on. +func PromptForTemplateId(ctx context.Context, ref, templateDir string) (TemplateId, error) { + if !cmdio.IsPromptSupported(ctx) { + return "", fmt.Errorf("please specify a template") + } + description, err := cmdio.SelectOrdered(ctx, options(), "Template to use") + if err != nil { + return "", err + } + + for _, template := range allTemplates { + if template.Description == description { + return template.Id, nil + } + } + + panic("this should never happen - template not found") +} + +func (tmpl *Template) InitializeWriter(configPath string, outputFiler filer.Filer) { + tmpl.Writer.Initialize(tmpl.Reader, configPath, outputFiler) +} + +func (tmpl *Template) SetReader(r Reader) { + tmpl.Reader = r +} + +func Get(id TemplateId) *Template { + for _, template := range allTemplates { + if template.Id == id { + return &template + } + } + + return nil +} diff --git a/cmd/bundle/init_test.go b/libs/template/template_test.go similarity index 64% rename from cmd/bundle/init_test.go rename to libs/template/template_test.go index 475b2e149..6b6ca0d0e 100644 --- a/cmd/bundle/init_test.go +++ b/libs/template/template_test.go @@ -1,4 +1,4 @@ -package bundle +package template import ( "testing" @@ -7,12 +7,31 @@ import ( "github.com/stretchr/testify/assert" ) -func TestBundleInitIsRepoUrl(t *testing.T) { - assert.True(t, isRepoUrl("git@github.com:databricks/cli.git")) - assert.True(t, isRepoUrl("https://github.com/databricks/cli.git")) +func TestTemplateHelpDescriptions(t *testing.T) { + expected := `- default-python: The default Python template for Notebooks / Delta Live Tables / Workflows +- default-sql: The default SQL template for .sql files that run with Databricks SQL +- dbt-sql: The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks) +- mlops-stacks: The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)` + assert.Equal(t, expected, HelpDescriptions()) +} - assert.False(t, isRepoUrl("./local")) - assert.False(t, isRepoUrl("foo")) +func TestTemplateOptions(t *testing.T) { + expected := []cmdio.Tuple{ + {Name: "default-python", Id: "The default Python template for Notebooks / Delta Live Tables / Workflows"}, + {Name: "default-sql", Id: "The default SQL template for .sql files that run with Databricks SQL"}, + {Name: "dbt-sql", Id: "The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks)"}, + {Name: "mlops-stacks", Id: "The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)"}, + {Name: "custom", Id: "Bring your own template"}, + } + assert.Equal(t, expected, options()) +} + +func TestBundleInitIsRepoUrl(t *testing.T) { + assert.True(t, IsGitRepoUrl("git@github.com:databricks/cli.git")) + assert.True(t, IsGitRepoUrl("https://github.com/databricks/cli.git")) + + assert.False(t, IsGitRepoUrl("./local")) + assert.False(t, IsGitRepoUrl("foo")) } func TestBundleInitRepoName(t *testing.T) { @@ -26,29 +45,3 @@ func TestBundleInitRepoName(t *testing.T) { assert.Equal(t, "invalid-url", repoName("invalid-url")) assert.Equal(t, "www.github.com", repoName("https://www.github.com")) } - -func TestNativeTemplateOptions(t *testing.T) { - expected := []cmdio.Tuple{ - {Name: "default-python", Id: "The default Python template for Notebooks / Delta Live Tables / Workflows"}, - {Name: "default-sql", Id: "The default SQL template for .sql files that run with Databricks SQL"}, - {Name: "dbt-sql", Id: "The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks)"}, - {Name: "mlops-stacks", Id: "The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)"}, - {Name: "custom...", Id: "Bring your own template"}, - } - assert.Equal(t, expected, nativeTemplateOptions()) -} - -func TestNativeTemplateHelpDescriptions(t *testing.T) { - expected := `- default-python: The default Python template for Notebooks / Delta Live Tables / Workflows -- default-sql: The default SQL template for .sql files that run with Databricks SQL -- dbt-sql: The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks) -- mlops-stacks: The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)` - assert.Equal(t, expected, nativeTemplateHelpDescriptions()) -} - -func TestGetUrlForNativeTemplate(t *testing.T) { - assert.Equal(t, "https://github.com/databricks/mlops-stacks", getUrlForNativeTemplate("mlops-stacks")) - assert.Equal(t, "https://github.com/databricks/mlops-stacks", getUrlForNativeTemplate("mlops-stack")) - assert.Equal(t, "", getUrlForNativeTemplate("default-python")) - assert.Equal(t, "", getUrlForNativeTemplate("invalid")) -} diff --git a/libs/template/writer.go b/libs/template/writer.go new file mode 100644 index 000000000..29a5deb69 --- /dev/null +++ b/libs/template/writer.go @@ -0,0 +1,169 @@ +package template + +import ( + "context" + "errors" + "fmt" + "io/fs" + + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/filer" +) + +// TODO: Retain coverage for the missing schema test case +// func TestMaterializeForNonTemplateDirectory(t *testing.T) { +// tmpDir := t.TempDir() +// w, err := databricks.NewWorkspaceClient(&databricks.Config{}) +// require.NoError(t, err) +// ctx := root.SetWorkspaceClient(context.Background(), w) + +// tmpl := TemplateX{ +// TemplateOpts: TemplateOpts{ +// ConfigFilePath: "", +// TemplateFS: os.DirFS(tmpDir), +// OutputFiler: nil, +// }, +// } + +// // Try to materialize a non-template directory. +// err = tmpl.Materialize(ctx) +// assert.EqualError(t, err, fmt.Sprintf("not a bundle template: expected to find a template schema file at %s", schemaFileName)) +// } + + +// TODO: Add tests for these writers, mocking the cmdio library +// at the same time. +const ( + libraryDirName = "library" + templateDirName = "template" + schemaFileName = "databricks_template_schema.json" +) + +type Writer interface { + Initialize(reader Reader, configPath string, outputFiler filer.Filer) + Materialize(ctx context.Context) error + LogTelemetry(ctx context.Context) error +} + +type defaultWriter struct { + reader Reader + configPath string + outputFiler filer.Filer + + // Internal state + config *config + renderer *renderer +} + +func (tmpl *defaultWriter) Initialize(reader Reader, configPath string, outputFiler filer.Filer) { + tmpl.configPath = configPath + tmpl.outputFiler = outputFiler +} + +func (tmpl *defaultWriter) promptForInput(ctx context.Context) error { + readerFs, err := tmpl.reader.FS(ctx) + if err != nil { + return err + } + if _, err := fs.Stat(readerFs, schemaFileName); errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("not a bundle template: expected to find a template schema file at %s", schemaFileName) + } + + tmpl.config, err = newConfig(ctx, readerFs, schemaFileName) + if err != nil { + return err + } + + // Read and assign config values from file + if tmpl.configPath != "" { + err = tmpl.config.assignValuesFromFile(tmpl.configPath) + if err != nil { + return err + } + } + + helpers := loadHelpers(ctx) + tmpl.renderer, err = newRenderer(ctx, tmpl.config.values, helpers, readerFs, templateDirName, libraryDirName) + if err != nil { + return err + } + + // Print welcome message + welcome := tmpl.config.schema.WelcomeMessage + if welcome != "" { + welcome, err = tmpl.renderer.executeTemplate(welcome) + if err != nil { + return err + } + cmdio.LogString(ctx, welcome) + } + + // Prompt user for any missing config values. Assign default values if + // terminal is not TTY + err = tmpl.config.promptOrAssignDefaultValues(tmpl.renderer) + if err != nil { + return err + } + return tmpl.config.validate() +} + +func (tmpl *defaultWriter) printSuccessMessage(ctx context.Context) error { + success := tmpl.config.schema.SuccessMessage + if success == "" { + cmdio.LogString(ctx, "✨ Successfully initialized template") + return nil + } + + success, err := tmpl.renderer.executeTemplate(success) + if err != nil { + return err + } + cmdio.LogString(ctx, success) + return nil +} + +func (tmpl *defaultWriter) Materialize(ctx context.Context) error { + err := tmpl.promptForInput(ctx) + if err != nil { + return err + } + + // Walk the template file tree and compute in-memory representations of the + // output files. + err = tmpl.renderer.walk() + if err != nil { + return err + } + + // Flush the output files to disk. + err = tmpl.renderer.persistToDisk(ctx, tmpl.outputFiler) + if err != nil { + return err + } + + return tmpl.printSuccessMessage(ctx) +} + +func (tmpl *defaultWriter) LogTelemetry(ctx context.Context) error { + // no-op + return nil +} + +type writerWithTelemetry struct { + defaultWriter +} + +func (tmpl *writerWithTelemetry) LogTelemetry(ctx context.Context) error { + // Log telemetry. TODO. + return nil +} + +func NewWriterWithTelemetry(reader Reader, configPath string, outputFiler filer.Filer) Writer { + return &writerWithTelemetry{ + defaultWriter: defaultWriter{ + reader: reader, + configPath: configPath, + outputFiler: outputFiler, + }, + } +}