From d238dd833cf0e6019a309324d92d43926b4fff05 Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Fri, 3 Jan 2025 17:28:22 +0530 Subject: [PATCH] add resolver --- cmd/bundle/init.go | 90 +++++++-------------------------------- libs/template/reader.go | 43 ++----------------- libs/template/resolve.go | 70 ++++++++++++++++++++++++++++++ libs/template/template.go | 52 ++++++++++------------ libs/template/writer.go | 61 +++++++++++++++++--------- 5 files changed, 151 insertions(+), 165 deletions(-) create mode 100644 libs/template/resolve.go diff --git a/cmd/bundle/init.go b/cmd/bundle/init.go index 4da5a69be..307e367d6 100644 --- a/cmd/bundle/init.go +++ b/cmd/bundle/init.go @@ -1,40 +1,15 @@ package bundle import ( - "context" "errors" "fmt" - "path/filepath" - "strings" "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/libs/cmdio" - "github.com/databricks/cli/libs/dbr" - "github.com/databricks/cli/libs/filer" "github.com/databricks/cli/libs/template" "github.com/spf13/cobra" ) -func constructOutputFiler(ctx context.Context, outputDir string) (filer.Filer, error) { - outputDir, err := filepath.Abs(outputDir) - if err != nil { - return nil, err - } - - // If the CLI is running on DBR and we're writing to the workspace file system, - // use the extension-aware workspace filesystem filer to instantiate the template. - // - // It is not possible to write notebooks through the workspace filesystem's FUSE mount. - // Therefore this is the only way we can initialize templates that contain notebooks - // when running the CLI on DBR and initializing a template to the workspace. - // - if strings.HasPrefix(outputDir, "/Workspace/") && dbr.RunsOnRuntime(ctx) { - return filer.NewWorkspaceFilesExtensionsClient(root.WorkspaceClient(ctx), outputDir) - } - - return filer.NewLocalClient(outputDir) -} - func newInitCommand() *cobra.Command { cmd := &cobra.Command{ Use: "init [TEMPLATE_PATH]", @@ -62,61 +37,28 @@ See https://docs.databricks.com/en/dev-tools/bundles/templates.html for more inf cmd.Flags().StringVar(&tag, "branch", "", "Git branch to use for template initialization") cmd.RunE = func(cmd *cobra.Command, args []string) error { - if tag != "" && branch != "" { - return errors.New("only one of --tag or --branch can be specified") + r := template.Resolver{ + TemplatePathOrUrl: args[0], + ConfigFile: configFile, + OutputDir: outputDir, + TemplateDir: templateDir, + Tag: tag, + Branch: branch, } - // Git ref to use for template initialization - ref := branch - if tag != "" { - ref = tag - } - - var tmpl *template.Template - var err error ctx := cmd.Context() - - if len(args) > 0 { - // User already specified a template local path or a Git URL. Use that - // information to configure a reader for the template - tmpl = template.Get(template.Custom) - // TODO: Get rid of the name arg. - if template.IsGitRepoUrl(args[0]) { - tmpl.SetReader(template.NewGitReader("", args[0], ref, templateDir)) - } else { - tmpl.SetReader(template.NewLocalReader("", args[0])) - } - } else { - tmplId, err := template.PromptForTemplateId(cmd.Context(), ref, templateDir) - if tmplId == template.Custom { - // If a user selects custom during the prompt, ask them to provide a path or Git URL - // as a positional argument. - cmdio.LogString(ctx, "Please specify a path or Git repository to use a custom template.") - cmdio.LogString(ctx, "See https://docs.databricks.com/en/dev-tools/bundles/templates.html to learn more about custom templates.") - return nil - } - if err != nil { - return err - } - - tmpl = template.Get(tmplId) + tmpl, err := r.Resolve(ctx) + if errors.Is(err, template.ErrCustomSelected) { + cmdio.LogString(ctx, "Please specify a path or Git repository to use a custom template.") + cmdio.LogString(ctx, "See https://docs.databricks.com/en/dev-tools/bundles/templates.html to learn more about custom templates.") + return nil + } + if err != nil { + return err } - defer tmpl.Reader.Close() - outputFiler, err := constructOutputFiler(ctx, outputDir) - if err != nil { - return err - } - - tmpl.Writer.Initialize(tmpl.Reader, configFile, outputFiler) - - err = tmpl.Writer.Materialize(ctx) - if err != nil { - return err - } - - return tmpl.Writer.LogTelemetry(ctx) + return tmpl.Writer.Materialize(ctx, tmpl.Reader) } return cmd } diff --git a/libs/template/reader.go b/libs/template/reader.go index 6cfaf9cb6..19d4ec243 100644 --- a/libs/template/reader.go +++ b/libs/template/reader.go @@ -21,12 +21,10 @@ type Reader interface { // Close releases any resources associated with the reader // like cleaning up temporary directories. Close() error - - Name() string } type builtinReader struct { - name string + name TemplateName fsCached fs.FS } @@ -43,7 +41,7 @@ func (r *builtinReader) FS(ctx context.Context) (fs.FS, error) { var templateFS fs.FS for _, entry := range builtin { - if entry.Name == r.name { + if entry.Name == string(r.name) { templateFS = entry.FS break } @@ -57,13 +55,7 @@ func (r *builtinReader) Close() error { return nil } -func (r *builtinReader) Name() string { - return r.name -} - type gitReader struct { - name string - // URL of the git repository that contains the template gitUrl string // tag or branch to checkout ref string @@ -88,7 +80,7 @@ var gitUrlPrefixes = []string{ "git@", } -// TODO: Copy over tests for this function. +// TODO: Make private? func IsGitRepoUrl(url string) bool { result := false for _, prefix := range gitUrlPrefixes { @@ -100,16 +92,6 @@ func IsGitRepoUrl(url string) bool { return result } -// TODO: Can I remove the name from here and other readers? -func NewGitReader(name, gitUrl, ref, templateDir string) Reader { - return &gitReader{ - name: name, - gitUrl: gitUrl, - ref: ref, - templateDir: templateDir, - } -} - // TODO: Test the idempotency of this function as well. func (r *gitReader) FS(ctx context.Context) (fs.FS, error) { // If the FS has already been loaded, return it. @@ -147,10 +129,6 @@ func (r *gitReader) Close() error { return os.RemoveAll(r.tmpRepoDir) } -func (r *gitReader) Name() string { - return r.name -} - type localReader struct { name string // Path on the local filesystem that contains the template @@ -159,13 +137,6 @@ type localReader struct { fsCached fs.FS } -func NewLocalReader(name, path string) Reader { - return &localReader{ - name: name, - path: path, - } -} - func (r *localReader) FS(ctx context.Context) (fs.FS, error) { // If the FS has already been loaded, return it. if r.fsCached != nil { @@ -180,10 +151,6 @@ func (r *localReader) Close() error { return nil } -func (r *localReader) Name() string { - return r.name -} - type failReader struct{} func (r *failReader) FS(ctx context.Context) (fs.FS, error) { @@ -193,7 +160,3 @@ func (r *failReader) FS(ctx context.Context) (fs.FS, error) { func (r *failReader) Close() error { return fmt.Errorf("this is a placeholder reader that always fails. Please configure a real reader.") } - -func (r *failReader) Name() string { - return "failReader" -} diff --git a/libs/template/resolve.go b/libs/template/resolve.go new file mode 100644 index 000000000..a4099e275 --- /dev/null +++ b/libs/template/resolve.go @@ -0,0 +1,70 @@ +package template + +import ( + "context" + "errors" +) + +type Resolver struct { + TemplatePathOrUrl string + ConfigFile string + OutputDir string + TemplateDir string + Tag string + Branch string +} + +var ErrCustomSelected = errors.New("custom template selected") + +// Configures the reader and the writer for template and returns +// a handle to the template. +// Prompts the user if needed. +func (r Resolver) Resolve(ctx context.Context) (*Template, error) { + if r.Tag != "" && r.Branch != "" { + return nil, errors.New("only one of --tag or --branch can be specified") + } + + // Git ref to use for template initialization + ref := r.Branch + if r.Tag != "" { + ref = r.Tag + } + + var tmpl *Template + if r.TemplatePathOrUrl == "" { + // Prompt the user to select a template + // if a template path or URL is not provided. + tmplId, err := SelectTemplate(ctx) + if err != nil { + return nil, err + } + + if tmplId == Custom { + return nil, ErrCustomSelected + } + + tmpl = Get(tmplId) + } else { + // Based on the provided template path or URL, + // configure a reader for the template. + tmpl = Get(Custom) + if IsGitRepoUrl(r.TemplatePathOrUrl) { + tmpl.Reader = &gitReader{ + gitUrl: r.TemplatePathOrUrl, + ref: ref, + templateDir: r.TemplateDir, + } + } else { + tmpl.Reader = &localReader{ + path: r.TemplatePathOrUrl, + } + } + } + + err := tmpl.Writer.Configure(ctx, r.ConfigFile, r.OutputDir) + if err != nil { + return nil, err + } + + return tmpl, nil +} diff --git a/libs/template/template.go b/libs/template/template.go index 1467ff2e5..ec8e1ac15 100644 --- a/libs/template/template.go +++ b/libs/template/template.go @@ -6,14 +6,14 @@ import ( "strings" "github.com/databricks/cli/libs/cmdio" - "github.com/databricks/cli/libs/filer" ) type Template struct { + // TODO: Make private as much as possible. Reader Reader Writer Writer - Id TemplateId + Name TemplateName Description string Aliases []string Hidden bool @@ -30,52 +30,52 @@ type NativeTemplate struct { IsOwnedByDatabricks bool } -type TemplateId string +type TemplateName string const ( - DefaultPython TemplateId = "default-python" - DefaultSql TemplateId = "default-sql" - DbtSql TemplateId = "dbt-sql" - MlopsStacks TemplateId = "mlops-stacks" - DefaultPydabs TemplateId = "default-pydabs" - Custom TemplateId = "custom" + DefaultPython TemplateName = "default-python" + DefaultSql TemplateName = "default-sql" + DbtSql TemplateName = "dbt-sql" + MlopsStacks TemplateName = "mlops-stacks" + DefaultPydabs TemplateName = "default-pydabs" + Custom TemplateName = "custom" ) var allTemplates = []Template{ { - Id: DefaultPython, + Name: DefaultPython, Description: "The default Python template for Notebooks / Delta Live Tables / Workflows", Reader: &builtinReader{name: "default-python"}, Writer: &writerWithTelemetry{}, }, { - Id: DefaultSql, + Name: DefaultSql, Description: "The default SQL template for .sql files that run with Databricks SQL", Reader: &builtinReader{name: "default-sql"}, Writer: &writerWithTelemetry{}, }, { - Id: DbtSql, + Name: DbtSql, Description: "The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks)", Reader: &builtinReader{name: "dbt-sql"}, Writer: &writerWithTelemetry{}, }, { - Id: MlopsStacks, + Name: MlopsStacks, Description: "The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)", Aliases: []string{"mlops-stack"}, Reader: &gitReader{gitUrl: "https://github.com/databricks/mlops-stacks"}, Writer: &writerWithTelemetry{}, }, { - Id: DefaultPydabs, + Name: DefaultPydabs, Hidden: true, Description: "The default PyDABs template", Reader: &gitReader{gitUrl: "https://databricks.github.io/workflows-authoring-toolkit/pydabs-template.git"}, Writer: &writerWithTelemetry{}, }, { - Id: Custom, + Name: Custom, Description: "Bring your own template", Reader: &failReader{}, Writer: &defaultWriter{}, @@ -85,8 +85,8 @@ var allTemplates = []Template{ func HelpDescriptions() string { var lines []string for _, template := range allTemplates { - if template.Id != Custom && !template.Hidden { - lines = append(lines, fmt.Sprintf("- %s: %s", template.Id, template.Description)) + if template.Name != Custom && !template.Hidden { + lines = append(lines, fmt.Sprintf("- %s: %s", template.Name, template.Description)) } } return strings.Join(lines, "\n") @@ -99,7 +99,7 @@ func options() []cmdio.Tuple { continue } tuple := cmdio.Tuple{ - Name: string(template.Id), + Name: string(template.Name), Id: template.Description, } names = append(names, tuple) @@ -108,7 +108,7 @@ func options() []cmdio.Tuple { } // TODO CONTINUE defining the methods that the init command will finally rely on. -func PromptForTemplateId(ctx context.Context, ref, templateDir string) (TemplateId, error) { +func SelectTemplate(ctx context.Context) (TemplateName, error) { if !cmdio.IsPromptSupported(ctx) { return "", fmt.Errorf("please specify a template") } @@ -119,24 +119,16 @@ func PromptForTemplateId(ctx context.Context, ref, templateDir string) (Template for _, template := range allTemplates { if template.Description == description { - return template.Id, nil + return template.Name, nil } } panic("this should never happen - template not found") } -func (tmpl *Template) InitializeWriter(configPath string, outputFiler filer.Filer) { - tmpl.Writer.Initialize(tmpl.Reader, configPath, outputFiler) -} - -func (tmpl *Template) SetReader(r Reader) { - tmpl.Reader = r -} - -func Get(id TemplateId) *Template { +func Get(name TemplateName) *Template { for _, template := range allTemplates { - if template.Id == id { + if template.Name == name { return &template } } diff --git a/libs/template/writer.go b/libs/template/writer.go index 29a5deb69..b0ec1ad46 100644 --- a/libs/template/writer.go +++ b/libs/template/writer.go @@ -5,8 +5,12 @@ import ( "errors" "fmt" "io/fs" + "path/filepath" + "strings" + "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/dbr" "github.com/databricks/cli/libs/filer" ) @@ -30,7 +34,6 @@ import ( // assert.EqualError(t, err, fmt.Sprintf("not a bundle template: expected to find a template schema file at %s", schemaFileName)) // } - // TODO: Add tests for these writers, mocking the cmdio library // at the same time. const ( @@ -40,13 +43,12 @@ const ( ) type Writer interface { - Initialize(reader Reader, configPath string, outputFiler filer.Filer) - Materialize(ctx context.Context) error + Configure(ctx context.Context, configPath, outputDir string) error + Materialize(ctx context.Context, r Reader) error LogTelemetry(ctx context.Context) error } type defaultWriter struct { - reader Reader configPath string outputFiler filer.Filer @@ -55,13 +57,40 @@ type defaultWriter struct { renderer *renderer } -func (tmpl *defaultWriter) Initialize(reader Reader, configPath string, outputFiler filer.Filer) { - tmpl.configPath = configPath - tmpl.outputFiler = outputFiler +func constructOutputFiler(ctx context.Context, outputDir string) (filer.Filer, error) { + outputDir, err := filepath.Abs(outputDir) + if err != nil { + return nil, err + } + + // If the CLI is running on DBR and we're writing to the workspace file system, + // use the extension-aware workspace filesystem filer to instantiate the template. + // + // It is not possible to write notebooks through the workspace filesystem's FUSE mount. + // Therefore this is the only way we can initialize templates that contain notebooks + // when running the CLI on DBR and initializing a template to the workspace. + // + if strings.HasPrefix(outputDir, "/Workspace/") && dbr.RunsOnRuntime(ctx) { + return filer.NewWorkspaceFilesExtensionsClient(root.WorkspaceClient(ctx), outputDir) + } + + return filer.NewLocalClient(outputDir) } -func (tmpl *defaultWriter) promptForInput(ctx context.Context) error { - readerFs, err := tmpl.reader.FS(ctx) +func (tmpl *defaultWriter) Configure(ctx context.Context, configPath string, outputDir string) error { + tmpl.configPath = configPath + + outputFiler, err := constructOutputFiler(ctx, outputDir) + if err != nil { + return err + } + + tmpl.outputFiler = outputFiler + return nil +} + +func (tmpl *defaultWriter) promptForInput(ctx context.Context, reader Reader) error { + readerFs, err := reader.FS(ctx) if err != nil { return err } @@ -122,8 +151,8 @@ func (tmpl *defaultWriter) printSuccessMessage(ctx context.Context) error { return nil } -func (tmpl *defaultWriter) Materialize(ctx context.Context) error { - err := tmpl.promptForInput(ctx) +func (tmpl *defaultWriter) Materialize(ctx context.Context, reader Reader) error { + err := tmpl.promptForInput(ctx, reader) if err != nil { return err } @@ -157,13 +186,3 @@ func (tmpl *writerWithTelemetry) LogTelemetry(ctx context.Context) error { // Log telemetry. TODO. return nil } - -func NewWriterWithTelemetry(reader Reader, configPath string, outputFiler filer.Filer) Writer { - return &writerWithTelemetry{ - defaultWriter: defaultWriter{ - reader: reader, - configPath: configPath, - outputFiler: outputFiler, - }, - } -}