add unit tests for reader

This commit is contained in:
Shreyas Goenka 2025-01-03 18:29:50 +05:30
parent d238dd833c
commit 2965c30268
No known key found for this signature in database
GPG Key ID: 92A07DF49CCB0622
8 changed files with 190 additions and 84 deletions

View File

@ -8,13 +8,17 @@ import (
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"strings"
"github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle"
"github.com/databricks/cli/cmd/root"
"github.com/databricks/cli/internal/testcli" "github.com/databricks/cli/internal/testcli"
"github.com/databricks/cli/internal/testutil" "github.com/databricks/cli/internal/testutil"
"github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/env" "github.com/databricks/cli/libs/env"
"github.com/databricks/cli/libs/flags"
"github.com/databricks/cli/libs/folders" "github.com/databricks/cli/libs/folders"
"github.com/databricks/cli/libs/template"
"github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -27,32 +31,28 @@ func initTestTemplate(t testutil.TestingT, ctx context.Context, templateName str
} }
func initTestTemplateWithBundleRoot(t testutil.TestingT, ctx context.Context, templateName string, config map[string]any, bundleRoot string) string { func initTestTemplateWithBundleRoot(t testutil.TestingT, ctx context.Context, templateName string, config map[string]any, bundleRoot string) string {
return "" templateRoot := filepath.Join("bundles", templateName)
// TODO: Make this function work but do not log telemetry. configFilePath := writeConfigFile(t, config)
// templateRoot := filepath.Join("bundles", templateName) ctx = root.SetWorkspaceClient(ctx, nil)
cmd := cmdio.NewIO(ctx, flags.OutputJSON, strings.NewReader(""), os.Stdout, os.Stderr, "", "bundles")
ctx = cmdio.InContext(ctx, cmd)
// configFilePath := writeConfigFile(t, config) r := template.Resolver{
TemplatePathOrUrl: templateRoot,
ConfigFile: configFilePath,
OutputDir: bundleRoot,
}
// ctx = root.SetWorkspaceClient(ctx, nil) tmpl, err := r.Resolve(ctx)
// cmd := cmdio.NewIO(ctx, flags.OutputJSON, strings.NewReader(""), os.Stdout, os.Stderr, "", "bundles") require.NoError(t, err)
// ctx = cmdio.InContext(ctx, cmd) defer tmpl.Reader.Close()
// ctx = telemetry.WithMockLogger(ctx)
// out, err := filer.NewLocalClient(bundleRoot) err = tmpl.Writer.Materialize(ctx, tmpl.Reader)
// require.NoError(t, err) require.NoError(t, err)
// tmpl := template.TemplateX{
// TemplateOpts: template.TemplateOpts{
// ConfigFilePath: configFilePath,
// TemplateFS: os.DirFS(templateRoot),
// OutputFiler: out,
// },
// }
// err = tmpl.Materialize(ctx) return bundleRoot
// require.NoError(t, err)
// return bundleRoot
} }
func writeConfigFile(t testutil.TestingT, config map[string]any) string { func writeConfigFile(t testutil.TestingT, config map[string]any) string {

View File

@ -8,15 +8,14 @@ import (
//go:embed all:templates //go:embed all:templates
var builtinTemplates embed.FS var builtinTemplates embed.FS
// BuiltinTemplate represents a template that is built into the CLI. // builtinTemplate represents a template that is built into the CLI.
type BuiltinTemplate struct { type builtinTemplate struct {
Name string Name string
FS fs.FS FS fs.FS
} }
// Builtin returns the list of all built-in templates. // builtin returns the list of all built-in templates.
// TODO: Make private? func builtin() ([]builtinTemplate, error) {
func Builtin() ([]BuiltinTemplate, error) {
templates, err := fs.Sub(builtinTemplates, "templates") templates, err := fs.Sub(builtinTemplates, "templates")
if err != nil { if err != nil {
return nil, err return nil, err
@ -27,7 +26,7 @@ func Builtin() ([]BuiltinTemplate, error) {
return nil, err return nil, err
} }
var out []BuiltinTemplate var out []builtinTemplate
for _, entry := range entries { for _, entry := range entries {
if !entry.IsDir() { if !entry.IsDir() {
continue continue
@ -38,7 +37,7 @@ func Builtin() ([]BuiltinTemplate, error) {
return nil, err return nil, err
} }
out = append(out, BuiltinTemplate{ out = append(out, builtinTemplate{
Name: entry.Name(), Name: entry.Name(),
FS: templateFS, FS: templateFS,
}) })

View File

@ -9,12 +9,12 @@ import (
) )
func TestBuiltin(t *testing.T) { func TestBuiltin(t *testing.T) {
out, err := Builtin() out, err := builtin()
require.NoError(t, err) require.NoError(t, err)
assert.GreaterOrEqual(t, len(out), 3) assert.GreaterOrEqual(t, len(out), 3)
// Create a map of templates by name for easier lookup // Create a map of templates by name for easier lookup
templates := make(map[string]*BuiltinTemplate) templates := make(map[string]*builtinTemplate)
for _, tmpl := range out { for _, tmpl := range out {
templates[tmpl.Name] = &tmpl templates[tmpl.Name] = &tmpl
} }

View File

@ -9,10 +9,8 @@ import (
"strings" "strings"
"github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/git"
) )
// TODO: Add tests for all these readers.
type Reader interface { type Reader interface {
// FS returns a file system that contains the template // FS returns a file system that contains the template
// definition files. This function is NOT thread safe. // definition files. This function is NOT thread safe.
@ -24,7 +22,7 @@ type Reader interface {
} }
type builtinReader struct { type builtinReader struct {
name TemplateName name string
fsCached fs.FS fsCached fs.FS
} }
@ -34,19 +32,23 @@ func (r *builtinReader) FS(ctx context.Context) (fs.FS, error) {
return r.fsCached, nil return r.fsCached, nil
} }
builtin, err := Builtin() builtin, err := builtin()
if err != nil { if err != nil {
return nil, err return nil, err
} }
var templateFS fs.FS var templateFS fs.FS
for _, entry := range builtin { for _, entry := range builtin {
if entry.Name == string(r.name) { if entry.Name == r.name {
templateFS = entry.FS templateFS = entry.FS
break break
} }
} }
if templateFS == nil {
return nil, fmt.Errorf("builtin template %s not found", r.name)
}
r.fsCached = templateFS r.fsCached = templateFS
return r.fsCached, nil return r.fsCached, nil
} }
@ -64,6 +66,10 @@ type gitReader struct {
// temporary directory where the repository is cloned // temporary directory where the repository is cloned
tmpRepoDir string tmpRepoDir string
// Function to clone the repository. This is a function pointer to allow
// mocking in tests.
cloneFunc func(ctx context.Context, url, reference, targetPath string) error
fsCached fs.FS fsCached fs.FS
} }
@ -80,8 +86,7 @@ var gitUrlPrefixes = []string{
"git@", "git@",
} }
// TODO: Make private? func isRepoUrl(url string) bool {
func IsGitRepoUrl(url string) bool {
result := false result := false
for _, prefix := range gitUrlPrefixes { for _, prefix := range gitUrlPrefixes {
if strings.HasPrefix(url, prefix) { if strings.HasPrefix(url, prefix) {
@ -92,7 +97,6 @@ func IsGitRepoUrl(url string) bool {
return result return result
} }
// TODO: Test the idempotency of this function as well.
func (r *gitReader) FS(ctx context.Context) (fs.FS, error) { func (r *gitReader) FS(ctx context.Context) (fs.FS, error) {
// If the FS has already been loaded, return it. // If the FS has already been loaded, return it.
if r.fsCached != nil { if r.fsCached != nil {
@ -111,7 +115,7 @@ func (r *gitReader) FS(ctx context.Context) (fs.FS, error) {
promptSpinner := cmdio.Spinner(ctx) promptSpinner := cmdio.Spinner(ctx)
promptSpinner <- "Downloading the template\n" promptSpinner <- "Downloading the template\n"
err = git.Clone(ctx, r.gitUrl, r.ref, repoDir) err = r.cloneFunc(ctx, r.gitUrl, r.ref, repoDir)
close(promptSpinner) close(promptSpinner)
if err != nil { if err != nil {
return nil, err return nil, err
@ -130,7 +134,6 @@ func (r *gitReader) Close() error {
} }
type localReader struct { type localReader struct {
name string
// Path on the local filesystem that contains the template // Path on the local filesystem that contains the template
path string path string

View File

@ -0,0 +1,112 @@
package template
import (
"context"
"io"
"os"
"path/filepath"
"strings"
"testing"
"github.com/databricks/cli/internal/testutil"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/flags"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestBuiltInReader(t *testing.T) {
exists := []string{
"default-python",
"default-sql",
"dbt-sql",
}
for _, name := range exists {
r := &builtinReader{name: name}
fs, err := r.FS(context.Background())
assert.NoError(t, err)
assert.NotNil(t, fs)
}
// TODO: Read one of the files to confirm further test this reader.
r := &builtinReader{name: "doesnotexist"}
_, err := r.FS(context.Background())
assert.EqualError(t, err, "builtin template doesnotexist not found")
}
func TestGitUrlReader(t *testing.T) {
ctx := context.Background()
cmd := cmdio.NewIO(ctx, flags.OutputJSON, strings.NewReader(""), os.Stdout, os.Stderr, "", "bundles")
ctx = cmdio.InContext(ctx, cmd)
var args []string
numCalls := 0
cloneFunc := func(ctx context.Context, url, reference, targetPath string) error {
numCalls++
args = []string{url, reference, targetPath}
err := os.MkdirAll(filepath.Join(targetPath, "a/b/c"), 0o755)
require.NoError(t, err)
testutil.WriteFile(t, filepath.Join(targetPath, "a", "b", "c", "somefile"), "somecontent")
return nil
}
r := &gitReader{
gitUrl: "someurl",
cloneFunc: cloneFunc,
ref: "sometag",
templateDir: "a/b/c",
}
// Assert cloneFunc is called with the correct args.
fs, err := r.FS(ctx)
require.NoError(t, err)
require.NotEmpty(t, r.tmpRepoDir)
assert.DirExists(t, r.tmpRepoDir)
assert.Equal(t, []string{"someurl", "sometag", r.tmpRepoDir}, args)
// Assert the fs returned is rooted at the templateDir.
fd, err := fs.Open("somefile")
require.NoError(t, err)
defer fd.Close()
b, err := io.ReadAll(fd)
require.NoError(t, err)
assert.Equal(t, "somecontent", string(b))
// Assert the FS is cached. cloneFunc should not be called again.
_, err = r.FS(ctx)
require.NoError(t, err)
assert.Equal(t, 1, numCalls)
// Assert Close cleans up the tmpRepoDir.
err = r.Close()
require.NoError(t, err)
assert.NoDirExists(t, r.tmpRepoDir)
}
func TestLocalReader(t *testing.T) {
tmpDir := t.TempDir()
testutil.WriteFile(t, filepath.Join(tmpDir, "somefile"), "somecontent")
ctx := context.Background()
r := &localReader{path: tmpDir}
fs, err := r.FS(ctx)
require.NoError(t, err)
// Assert the fs returned is rooted at correct location.
fd, err := fs.Open("somefile")
require.NoError(t, err)
defer fd.Close()
b, err := io.ReadAll(fd)
require.NoError(t, err)
assert.Equal(t, "somecontent", string(b))
// Assert close does not error
assert.NoError(t, r.Close())
}
func TestFailReader(t *testing.T) {
r := &failReader{}
assert.Error(t, r.Close())
_, err := r.FS(context.Background())
assert.Error(t, err)
}

View File

@ -3,6 +3,8 @@ package template
import ( import (
"context" "context"
"errors" "errors"
"github.com/databricks/cli/libs/git"
) )
type Resolver struct { type Resolver struct {
@ -48,11 +50,12 @@ func (r Resolver) Resolve(ctx context.Context) (*Template, error) {
// Based on the provided template path or URL, // Based on the provided template path or URL,
// configure a reader for the template. // configure a reader for the template.
tmpl = Get(Custom) tmpl = Get(Custom)
if IsGitRepoUrl(r.TemplatePathOrUrl) { if isRepoUrl(r.TemplatePathOrUrl) {
tmpl.Reader = &gitReader{ tmpl.Reader = &gitReader{
gitUrl: r.TemplatePathOrUrl, gitUrl: r.TemplatePathOrUrl,
ref: ref, ref: ref,
templateDir: r.TemplateDir, templateDir: r.TemplateDir,
cloneFunc: git.Clone,
} }
} else { } else {
tmpl.Reader = &localReader{ tmpl.Reader = &localReader{

View File

@ -6,28 +6,17 @@ import (
"strings" "strings"
"github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/git"
) )
type Template struct { type Template struct {
// TODO: Make private as much as possible.
Reader Reader Reader Reader
Writer Writer Writer Writer
Name TemplateName name TemplateName
Description string description string
Aliases []string aliases []string
Hidden bool hidden bool
}
// TODO: Make details private?
// TODO: Combine this with the generic template struct?
type NativeTemplate struct {
Name string
Description string
Aliases []string
GitUrl string
Hidden bool
IsOwnedByDatabricks bool
} }
type TemplateName string type TemplateName string
@ -43,40 +32,40 @@ const (
var allTemplates = []Template{ var allTemplates = []Template{
{ {
Name: DefaultPython, name: DefaultPython,
Description: "The default Python template for Notebooks / Delta Live Tables / Workflows", description: "The default Python template for Notebooks / Delta Live Tables / Workflows",
Reader: &builtinReader{name: "default-python"}, Reader: &builtinReader{name: "default-python"},
Writer: &writerWithTelemetry{}, Writer: &writerWithTelemetry{},
}, },
{ {
Name: DefaultSql, name: DefaultSql,
Description: "The default SQL template for .sql files that run with Databricks SQL", description: "The default SQL template for .sql files that run with Databricks SQL",
Reader: &builtinReader{name: "default-sql"}, Reader: &builtinReader{name: "default-sql"},
Writer: &writerWithTelemetry{}, Writer: &writerWithTelemetry{},
}, },
{ {
Name: DbtSql, name: DbtSql,
Description: "The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks)", description: "The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks)",
Reader: &builtinReader{name: "dbt-sql"}, Reader: &builtinReader{name: "dbt-sql"},
Writer: &writerWithTelemetry{}, Writer: &writerWithTelemetry{},
}, },
{ {
Name: MlopsStacks, name: MlopsStacks,
Description: "The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)", description: "The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)",
Aliases: []string{"mlops-stack"}, aliases: []string{"mlops-stack"},
Reader: &gitReader{gitUrl: "https://github.com/databricks/mlops-stacks"}, Reader: &gitReader{gitUrl: "https://github.com/databricks/mlops-stacks", cloneFunc: git.Clone},
Writer: &writerWithTelemetry{}, Writer: &writerWithTelemetry{},
}, },
{ {
Name: DefaultPydabs, name: DefaultPydabs,
Hidden: true, hidden: true,
Description: "The default PyDABs template", description: "The default PyDABs template",
Reader: &gitReader{gitUrl: "https://databricks.github.io/workflows-authoring-toolkit/pydabs-template.git"}, Reader: &gitReader{gitUrl: "https://databricks.github.io/workflows-authoring-toolkit/pydabs-template.git", cloneFunc: git.Clone},
Writer: &writerWithTelemetry{}, Writer: &writerWithTelemetry{},
}, },
{ {
Name: Custom, name: Custom,
Description: "Bring your own template", description: "Bring your own template",
Reader: &failReader{}, Reader: &failReader{},
Writer: &defaultWriter{}, Writer: &defaultWriter{},
}, },
@ -85,8 +74,8 @@ var allTemplates = []Template{
func HelpDescriptions() string { func HelpDescriptions() string {
var lines []string var lines []string
for _, template := range allTemplates { for _, template := range allTemplates {
if template.Name != Custom && !template.Hidden { if template.name != Custom && !template.hidden {
lines = append(lines, fmt.Sprintf("- %s: %s", template.Name, template.Description)) lines = append(lines, fmt.Sprintf("- %s: %s", template.name, template.description))
} }
} }
return strings.Join(lines, "\n") return strings.Join(lines, "\n")
@ -95,12 +84,12 @@ func HelpDescriptions() string {
func options() []cmdio.Tuple { func options() []cmdio.Tuple {
names := make([]cmdio.Tuple, 0, len(allTemplates)) names := make([]cmdio.Tuple, 0, len(allTemplates))
for _, template := range allTemplates { for _, template := range allTemplates {
if template.Hidden { if template.hidden {
continue continue
} }
tuple := cmdio.Tuple{ tuple := cmdio.Tuple{
Name: string(template.Name), Name: string(template.name),
Id: template.Description, Id: template.description,
} }
names = append(names, tuple) names = append(names, tuple)
} }
@ -118,8 +107,8 @@ func SelectTemplate(ctx context.Context) (TemplateName, error) {
} }
for _, template := range allTemplates { for _, template := range allTemplates {
if template.Description == description { if template.description == description {
return template.Name, nil return template.name, nil
} }
} }
@ -128,7 +117,7 @@ func SelectTemplate(ctx context.Context) (TemplateName, error) {
func Get(name TemplateName) *Template { func Get(name TemplateName) *Template {
for _, template := range allTemplates { for _, template := range allTemplates {
if template.Name == name { if template.name == name {
return &template return &template
} }
} }

View File

@ -27,11 +27,11 @@ func TestTemplateOptions(t *testing.T) {
} }
func TestBundleInitIsRepoUrl(t *testing.T) { func TestBundleInitIsRepoUrl(t *testing.T) {
assert.True(t, IsGitRepoUrl("git@github.com:databricks/cli.git")) assert.True(t, isRepoUrl("git@github.com:databricks/cli.git"))
assert.True(t, IsGitRepoUrl("https://github.com/databricks/cli.git")) assert.True(t, isRepoUrl("https://github.com/databricks/cli.git"))
assert.False(t, IsGitRepoUrl("./local")) assert.False(t, isRepoUrl("./local"))
assert.False(t, IsGitRepoUrl("foo")) assert.False(t, isRepoUrl("foo"))
} }
func TestBundleInitRepoName(t *testing.T) { func TestBundleInitRepoName(t *testing.T) {