add resolver

This commit is contained in:
Shreyas Goenka 2025-01-03 17:28:22 +05:30
parent a743139f83
commit d238dd833c
No known key found for this signature in database
GPG Key ID: 92A07DF49CCB0622
5 changed files with 151 additions and 165 deletions

View File

@ -1,40 +1,15 @@
package bundle package bundle
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"path/filepath"
"strings"
"github.com/databricks/cli/cmd/root" "github.com/databricks/cli/cmd/root"
"github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/dbr"
"github.com/databricks/cli/libs/filer"
"github.com/databricks/cli/libs/template" "github.com/databricks/cli/libs/template"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
func constructOutputFiler(ctx context.Context, outputDir string) (filer.Filer, error) {
outputDir, err := filepath.Abs(outputDir)
if err != nil {
return nil, err
}
// If the CLI is running on DBR and we're writing to the workspace file system,
// use the extension-aware workspace filesystem filer to instantiate the template.
//
// It is not possible to write notebooks through the workspace filesystem's FUSE mount.
// Therefore this is the only way we can initialize templates that contain notebooks
// when running the CLI on DBR and initializing a template to the workspace.
//
if strings.HasPrefix(outputDir, "/Workspace/") && dbr.RunsOnRuntime(ctx) {
return filer.NewWorkspaceFilesExtensionsClient(root.WorkspaceClient(ctx), outputDir)
}
return filer.NewLocalClient(outputDir)
}
func newInitCommand() *cobra.Command { func newInitCommand() *cobra.Command {
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "init [TEMPLATE_PATH]", Use: "init [TEMPLATE_PATH]",
@ -62,35 +37,18 @@ See https://docs.databricks.com/en/dev-tools/bundles/templates.html for more inf
cmd.Flags().StringVar(&tag, "branch", "", "Git branch to use for template initialization") cmd.Flags().StringVar(&tag, "branch", "", "Git branch to use for template initialization")
cmd.RunE = func(cmd *cobra.Command, args []string) error { cmd.RunE = func(cmd *cobra.Command, args []string) error {
if tag != "" && branch != "" { r := template.Resolver{
return errors.New("only one of --tag or --branch can be specified") TemplatePathOrUrl: args[0],
ConfigFile: configFile,
OutputDir: outputDir,
TemplateDir: templateDir,
Tag: tag,
Branch: branch,
} }
// Git ref to use for template initialization
ref := branch
if tag != "" {
ref = tag
}
var tmpl *template.Template
var err error
ctx := cmd.Context() ctx := cmd.Context()
tmpl, err := r.Resolve(ctx)
if len(args) > 0 { if errors.Is(err, template.ErrCustomSelected) {
// User already specified a template local path or a Git URL. Use that
// information to configure a reader for the template
tmpl = template.Get(template.Custom)
// TODO: Get rid of the name arg.
if template.IsGitRepoUrl(args[0]) {
tmpl.SetReader(template.NewGitReader("", args[0], ref, templateDir))
} else {
tmpl.SetReader(template.NewLocalReader("", args[0]))
}
} else {
tmplId, err := template.PromptForTemplateId(cmd.Context(), ref, templateDir)
if tmplId == template.Custom {
// If a user selects custom during the prompt, ask them to provide a path or Git URL
// as a positional argument.
cmdio.LogString(ctx, "Please specify a path or Git repository to use a custom template.") cmdio.LogString(ctx, "Please specify a path or Git repository to use a custom template.")
cmdio.LogString(ctx, "See https://docs.databricks.com/en/dev-tools/bundles/templates.html to learn more about custom templates.") cmdio.LogString(ctx, "See https://docs.databricks.com/en/dev-tools/bundles/templates.html to learn more about custom templates.")
return nil return nil
@ -98,25 +56,9 @@ See https://docs.databricks.com/en/dev-tools/bundles/templates.html for more inf
if err != nil { if err != nil {
return err return err
} }
tmpl = template.Get(tmplId)
}
defer tmpl.Reader.Close() defer tmpl.Reader.Close()
outputFiler, err := constructOutputFiler(ctx, outputDir) return tmpl.Writer.Materialize(ctx, tmpl.Reader)
if err != nil {
return err
}
tmpl.Writer.Initialize(tmpl.Reader, configFile, outputFiler)
err = tmpl.Writer.Materialize(ctx)
if err != nil {
return err
}
return tmpl.Writer.LogTelemetry(ctx)
} }
return cmd return cmd
} }

View File

@ -21,12 +21,10 @@ type Reader interface {
// Close releases any resources associated with the reader // Close releases any resources associated with the reader
// like cleaning up temporary directories. // like cleaning up temporary directories.
Close() error Close() error
Name() string
} }
type builtinReader struct { type builtinReader struct {
name string name TemplateName
fsCached fs.FS fsCached fs.FS
} }
@ -43,7 +41,7 @@ func (r *builtinReader) FS(ctx context.Context) (fs.FS, error) {
var templateFS fs.FS var templateFS fs.FS
for _, entry := range builtin { for _, entry := range builtin {
if entry.Name == r.name { if entry.Name == string(r.name) {
templateFS = entry.FS templateFS = entry.FS
break break
} }
@ -57,13 +55,7 @@ func (r *builtinReader) Close() error {
return nil return nil
} }
func (r *builtinReader) Name() string {
return r.name
}
type gitReader struct { type gitReader struct {
name string
// URL of the git repository that contains the template
gitUrl string gitUrl string
// tag or branch to checkout // tag or branch to checkout
ref string ref string
@ -88,7 +80,7 @@ var gitUrlPrefixes = []string{
"git@", "git@",
} }
// TODO: Copy over tests for this function. // TODO: Make private?
func IsGitRepoUrl(url string) bool { func IsGitRepoUrl(url string) bool {
result := false result := false
for _, prefix := range gitUrlPrefixes { for _, prefix := range gitUrlPrefixes {
@ -100,16 +92,6 @@ func IsGitRepoUrl(url string) bool {
return result return result
} }
// TODO: Can I remove the name from here and other readers?
func NewGitReader(name, gitUrl, ref, templateDir string) Reader {
return &gitReader{
name: name,
gitUrl: gitUrl,
ref: ref,
templateDir: templateDir,
}
}
// TODO: Test the idempotency of this function as well. // TODO: Test the idempotency of this function as well.
func (r *gitReader) FS(ctx context.Context) (fs.FS, error) { func (r *gitReader) FS(ctx context.Context) (fs.FS, error) {
// If the FS has already been loaded, return it. // If the FS has already been loaded, return it.
@ -147,10 +129,6 @@ func (r *gitReader) Close() error {
return os.RemoveAll(r.tmpRepoDir) return os.RemoveAll(r.tmpRepoDir)
} }
func (r *gitReader) Name() string {
return r.name
}
type localReader struct { type localReader struct {
name string name string
// Path on the local filesystem that contains the template // Path on the local filesystem that contains the template
@ -159,13 +137,6 @@ type localReader struct {
fsCached fs.FS fsCached fs.FS
} }
func NewLocalReader(name, path string) Reader {
return &localReader{
name: name,
path: path,
}
}
func (r *localReader) FS(ctx context.Context) (fs.FS, error) { func (r *localReader) FS(ctx context.Context) (fs.FS, error) {
// If the FS has already been loaded, return it. // If the FS has already been loaded, return it.
if r.fsCached != nil { if r.fsCached != nil {
@ -180,10 +151,6 @@ func (r *localReader) Close() error {
return nil return nil
} }
func (r *localReader) Name() string {
return r.name
}
type failReader struct{} type failReader struct{}
func (r *failReader) FS(ctx context.Context) (fs.FS, error) { func (r *failReader) FS(ctx context.Context) (fs.FS, error) {
@ -193,7 +160,3 @@ func (r *failReader) FS(ctx context.Context) (fs.FS, error) {
func (r *failReader) Close() error { func (r *failReader) Close() error {
return fmt.Errorf("this is a placeholder reader that always fails. Please configure a real reader.") return fmt.Errorf("this is a placeholder reader that always fails. Please configure a real reader.")
} }
func (r *failReader) Name() string {
return "failReader"
}

70
libs/template/resolve.go Normal file
View File

@ -0,0 +1,70 @@
package template
import (
"context"
"errors"
)
// Resolver collects the inputs needed to pick a bundle template and to
// configure its reader and writer. Call Resolve to obtain the template.
type Resolver struct {
// Local path or Git repository URL of a custom template. When empty,
// Resolve asks the user to select one of the registered templates instead.
TemplatePathOrUrl string
// Passed to the template writer's Configure call as the config path.
ConfigFile string
// Passed to the template writer's Configure call as the output directory.
OutputDir string
// Forwarded to the Git reader as its templateDir. Only consulted when
// TemplatePathOrUrl is a Git repository URL.
TemplateDir string
// Git tag to use as the ref for the Git reader. Mutually exclusive with
// Branch. Only consulted when TemplatePathOrUrl is a Git repository URL.
Tag string
// Git branch to use as the ref for the Git reader when Tag is empty.
// Mutually exclusive with Tag. Only consulted when TemplatePathOrUrl is a
// Git repository URL.
Branch string
}
// ErrCustomSelected is returned by Resolve when no template path or URL was
// provided and the user picked the "custom" entry at the selection prompt.
// Callers can detect it with errors.Is and tell the user to pass a path or
// Git repository explicitly.
var ErrCustomSelected = errors.New("custom template selected")
// Resolve determines which template to use, wires up its reader, configures
// its writer and returns a handle to the template.
//
// When TemplatePathOrUrl is empty the user is prompted to select one of the
// registered templates; selecting the custom entry yields ErrCustomSelected.
// Otherwise the custom template is used, reading either from a Git repository
// (when TemplatePathOrUrl looks like a Git URL) or from the local filesystem.
//
// An error is returned when both Tag and Branch are set, when template
// selection fails, or when the writer cannot be configured.
func (r Resolver) Resolve(ctx context.Context) (*Template, error) {
	// Tag and Branch both name a Git ref, so at most one of them may be set.
	if r.Tag != "" && r.Branch != "" {
		return nil, errors.New("only one of --tag or --branch can be specified")
	}

	// Prefer the tag as the Git ref; otherwise use the branch (possibly empty).
	gitRef := r.Tag
	if gitRef == "" {
		gitRef = r.Branch
	}

	var selected *Template
	switch {
	case r.TemplatePathOrUrl == "":
		// Nothing was specified up front: let the user choose a template.
		name, err := SelectTemplate(ctx)
		if err != nil {
			return nil, err
		}
		if name == Custom {
			return nil, ErrCustomSelected
		}
		selected = Get(name)

	case IsGitRepoUrl(r.TemplatePathOrUrl):
		// Custom template hosted in a Git repository.
		selected = Get(Custom)
		selected.Reader = &gitReader{
			gitUrl:      r.TemplatePathOrUrl,
			ref:         gitRef,
			templateDir: r.TemplateDir,
		}

	default:
		// Custom template located on the local filesystem.
		selected = Get(Custom)
		selected.Reader = &localReader{
			path: r.TemplatePathOrUrl,
		}
	}

	if err := selected.Writer.Configure(ctx, r.ConfigFile, r.OutputDir); err != nil {
		return nil, err
	}
	return selected, nil
}

View File

@ -6,14 +6,14 @@ import (
"strings" "strings"
"github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/filer"
) )
type Template struct { type Template struct {
// TODO: Make private as much as possible.
Reader Reader Reader Reader
Writer Writer Writer Writer
Id TemplateId Name TemplateName
Description string Description string
Aliases []string Aliases []string
Hidden bool Hidden bool
@ -30,52 +30,52 @@ type NativeTemplate struct {
IsOwnedByDatabricks bool IsOwnedByDatabricks bool
} }
type TemplateId string type TemplateName string
const ( const (
DefaultPython TemplateId = "default-python" DefaultPython TemplateName = "default-python"
DefaultSql TemplateId = "default-sql" DefaultSql TemplateName = "default-sql"
DbtSql TemplateId = "dbt-sql" DbtSql TemplateName = "dbt-sql"
MlopsStacks TemplateId = "mlops-stacks" MlopsStacks TemplateName = "mlops-stacks"
DefaultPydabs TemplateId = "default-pydabs" DefaultPydabs TemplateName = "default-pydabs"
Custom TemplateId = "custom" Custom TemplateName = "custom"
) )
var allTemplates = []Template{ var allTemplates = []Template{
{ {
Id: DefaultPython, Name: DefaultPython,
Description: "The default Python template for Notebooks / Delta Live Tables / Workflows", Description: "The default Python template for Notebooks / Delta Live Tables / Workflows",
Reader: &builtinReader{name: "default-python"}, Reader: &builtinReader{name: "default-python"},
Writer: &writerWithTelemetry{}, Writer: &writerWithTelemetry{},
}, },
{ {
Id: DefaultSql, Name: DefaultSql,
Description: "The default SQL template for .sql files that run with Databricks SQL", Description: "The default SQL template for .sql files that run with Databricks SQL",
Reader: &builtinReader{name: "default-sql"}, Reader: &builtinReader{name: "default-sql"},
Writer: &writerWithTelemetry{}, Writer: &writerWithTelemetry{},
}, },
{ {
Id: DbtSql, Name: DbtSql,
Description: "The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks)", Description: "The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks)",
Reader: &builtinReader{name: "dbt-sql"}, Reader: &builtinReader{name: "dbt-sql"},
Writer: &writerWithTelemetry{}, Writer: &writerWithTelemetry{},
}, },
{ {
Id: MlopsStacks, Name: MlopsStacks,
Description: "The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)", Description: "The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)",
Aliases: []string{"mlops-stack"}, Aliases: []string{"mlops-stack"},
Reader: &gitReader{gitUrl: "https://github.com/databricks/mlops-stacks"}, Reader: &gitReader{gitUrl: "https://github.com/databricks/mlops-stacks"},
Writer: &writerWithTelemetry{}, Writer: &writerWithTelemetry{},
}, },
{ {
Id: DefaultPydabs, Name: DefaultPydabs,
Hidden: true, Hidden: true,
Description: "The default PyDABs template", Description: "The default PyDABs template",
Reader: &gitReader{gitUrl: "https://databricks.github.io/workflows-authoring-toolkit/pydabs-template.git"}, Reader: &gitReader{gitUrl: "https://databricks.github.io/workflows-authoring-toolkit/pydabs-template.git"},
Writer: &writerWithTelemetry{}, Writer: &writerWithTelemetry{},
}, },
{ {
Id: Custom, Name: Custom,
Description: "Bring your own template", Description: "Bring your own template",
Reader: &failReader{}, Reader: &failReader{},
Writer: &defaultWriter{}, Writer: &defaultWriter{},
@ -85,8 +85,8 @@ var allTemplates = []Template{
func HelpDescriptions() string { func HelpDescriptions() string {
var lines []string var lines []string
for _, template := range allTemplates { for _, template := range allTemplates {
if template.Id != Custom && !template.Hidden { if template.Name != Custom && !template.Hidden {
lines = append(lines, fmt.Sprintf("- %s: %s", template.Id, template.Description)) lines = append(lines, fmt.Sprintf("- %s: %s", template.Name, template.Description))
} }
} }
return strings.Join(lines, "\n") return strings.Join(lines, "\n")
@ -99,7 +99,7 @@ func options() []cmdio.Tuple {
continue continue
} }
tuple := cmdio.Tuple{ tuple := cmdio.Tuple{
Name: string(template.Id), Name: string(template.Name),
Id: template.Description, Id: template.Description,
} }
names = append(names, tuple) names = append(names, tuple)
@ -108,7 +108,7 @@ func options() []cmdio.Tuple {
} }
// TODO CONTINUE defining the methods that the init command will finally rely on. // TODO CONTINUE defining the methods that the init command will finally rely on.
func PromptForTemplateId(ctx context.Context, ref, templateDir string) (TemplateId, error) { func SelectTemplate(ctx context.Context) (TemplateName, error) {
if !cmdio.IsPromptSupported(ctx) { if !cmdio.IsPromptSupported(ctx) {
return "", fmt.Errorf("please specify a template") return "", fmt.Errorf("please specify a template")
} }
@ -119,24 +119,16 @@ func PromptForTemplateId(ctx context.Context, ref, templateDir string) (Template
for _, template := range allTemplates { for _, template := range allTemplates {
if template.Description == description { if template.Description == description {
return template.Id, nil return template.Name, nil
} }
} }
panic("this should never happen - template not found") panic("this should never happen - template not found")
} }
func (tmpl *Template) InitializeWriter(configPath string, outputFiler filer.Filer) { func Get(name TemplateName) *Template {
tmpl.Writer.Initialize(tmpl.Reader, configPath, outputFiler)
}
func (tmpl *Template) SetReader(r Reader) {
tmpl.Reader = r
}
func Get(id TemplateId) *Template {
for _, template := range allTemplates { for _, template := range allTemplates {
if template.Id == id { if template.Name == name {
return &template return &template
} }
} }

View File

@ -5,8 +5,12 @@ import (
"errors" "errors"
"fmt" "fmt"
"io/fs" "io/fs"
"path/filepath"
"strings"
"github.com/databricks/cli/cmd/root"
"github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/dbr"
"github.com/databricks/cli/libs/filer" "github.com/databricks/cli/libs/filer"
) )
@ -30,7 +34,6 @@ import (
// assert.EqualError(t, err, fmt.Sprintf("not a bundle template: expected to find a template schema file at %s", schemaFileName)) // assert.EqualError(t, err, fmt.Sprintf("not a bundle template: expected to find a template schema file at %s", schemaFileName))
// } // }
// TODO: Add tests for these writers, mocking the cmdio library // TODO: Add tests for these writers, mocking the cmdio library
// at the same time. // at the same time.
const ( const (
@ -40,13 +43,12 @@ const (
) )
type Writer interface { type Writer interface {
Initialize(reader Reader, configPath string, outputFiler filer.Filer) Configure(ctx context.Context, configPath, outputDir string) error
Materialize(ctx context.Context) error Materialize(ctx context.Context, r Reader) error
LogTelemetry(ctx context.Context) error LogTelemetry(ctx context.Context) error
} }
type defaultWriter struct { type defaultWriter struct {
reader Reader
configPath string configPath string
outputFiler filer.Filer outputFiler filer.Filer
@ -55,13 +57,40 @@ type defaultWriter struct {
renderer *renderer renderer *renderer
} }
func (tmpl *defaultWriter) Initialize(reader Reader, configPath string, outputFiler filer.Filer) { func constructOutputFiler(ctx context.Context, outputDir string) (filer.Filer, error) {
tmpl.configPath = configPath outputDir, err := filepath.Abs(outputDir)
tmpl.outputFiler = outputFiler if err != nil {
return nil, err
}
// If the CLI is running on DBR and we're writing to the workspace file system,
// use the extension-aware workspace filesystem filer to instantiate the template.
//
// It is not possible to write notebooks through the workspace filesystem's FUSE mount.
// Therefore this is the only way we can initialize templates that contain notebooks
// when running the CLI on DBR and initializing a template to the workspace.
//
if strings.HasPrefix(outputDir, "/Workspace/") && dbr.RunsOnRuntime(ctx) {
return filer.NewWorkspaceFilesExtensionsClient(root.WorkspaceClient(ctx), outputDir)
}
return filer.NewLocalClient(outputDir)
} }
func (tmpl *defaultWriter) promptForInput(ctx context.Context) error { func (tmpl *defaultWriter) Configure(ctx context.Context, configPath string, outputDir string) error {
readerFs, err := tmpl.reader.FS(ctx) tmpl.configPath = configPath
outputFiler, err := constructOutputFiler(ctx, outputDir)
if err != nil {
return err
}
tmpl.outputFiler = outputFiler
return nil
}
func (tmpl *defaultWriter) promptForInput(ctx context.Context, reader Reader) error {
readerFs, err := reader.FS(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -122,8 +151,8 @@ func (tmpl *defaultWriter) printSuccessMessage(ctx context.Context) error {
return nil return nil
} }
func (tmpl *defaultWriter) Materialize(ctx context.Context) error { func (tmpl *defaultWriter) Materialize(ctx context.Context, reader Reader) error {
err := tmpl.promptForInput(ctx) err := tmpl.promptForInput(ctx, reader)
if err != nil { if err != nil {
return err return err
} }
@ -157,13 +186,3 @@ func (tmpl *writerWithTelemetry) LogTelemetry(ctx context.Context) error {
// Log telemetry. TODO. // Log telemetry. TODO.
return nil return nil
} }
func NewWriterWithTelemetry(reader Reader, configPath string, outputFiler filer.Filer) Writer {
return &writerWithTelemetry{
defaultWriter: defaultWriter{
reader: reader,
configPath: configPath,
outputFiler: outputFiler,
},
}
}