diff --git a/CHANGELOG.md b/CHANGELOG.md index 6fcbab8ce..867e086be 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,142 @@ # Version changelog +## 0.205.0 + +This release marks the public preview phase of Databricks Asset Bundles. + +For more information, please refer to our online documentation at +https://docs.databricks.com/en/dev-tools/bundles/. + +CLI: + * Prompt once for a client profile ([#727](https://github.com/databricks/cli/pull/727)). + +Bundles: + * Use clearer error message when no interpolation value is found. ([#764](https://github.com/databricks/cli/pull/764)). + * Use interactive prompt to select resource to run if not specified ([#762](https://github.com/databricks/cli/pull/762)). + * Add documentation link bundle command group description ([#770](https://github.com/databricks/cli/pull/770)). + + +## 0.204.1 + +Bundles: + * Fix conversion of job parameters ([#744](https://github.com/databricks/cli/pull/744)). + * Add schema and config validation to jsonschema package ([#740](https://github.com/databricks/cli/pull/740)). + * Support Model Serving Endpoints in bundles ([#682](https://github.com/databricks/cli/pull/682)). + * Do not include empty output in job run output ([#749](https://github.com/databricks/cli/pull/749)). + * Fixed marking libraries from DBFS as remote ([#750](https://github.com/databricks/cli/pull/750)). + * Process only Python wheel tasks which have local libraries used ([#751](https://github.com/databricks/cli/pull/751)). + * Add enum support for bundle templates ([#668](https://github.com/databricks/cli/pull/668)). + * Apply Python wheel trampoline if workspace library is used ([#755](https://github.com/databricks/cli/pull/755)). + * List available targets when incorrect target passed ([#756](https://github.com/databricks/cli/pull/756)). + * Make bundle and sync fields optional ([#757](https://github.com/databricks/cli/pull/757)). + * Consolidate environment variable interaction ([#747](https://github.com/databricks/cli/pull/747)). + +Internal: + * Update Go SDK to v0.19.1 ([#759](https://github.com/databricks/cli/pull/759)). + + + +## 0.204.0 + +This release includes permission related commands for a subset of workspace +services where they apply. These complement the `permissions` command and +do not require specification of the object type to work with, as that is +implied by the command they are nested under. + +CLI: + * Group permission related commands ([#730](https://github.com/databricks/cli/pull/730)). + +Bundles: + * Fixed artifact file uploading on Windows and wheel execution on DBR 13.3 ([#722](https://github.com/databricks/cli/pull/722)). + * Make resource and artifact paths in bundle config relative to config folder ([#708](https://github.com/databricks/cli/pull/708)). + * Add support for ordering of input prompts ([#662](https://github.com/databricks/cli/pull/662)). + * Fix IsServicePrincipal() only working for workspace admins ([#732](https://github.com/databricks/cli/pull/732)). + * databricks bundle init template v1 ([#686](https://github.com/databricks/cli/pull/686)). + * databricks bundle init template v2: optional stubs, DLT support ([#700](https://github.com/databricks/cli/pull/700)). + * Show 'databricks bundle init' template in CLI prompt ([#725](https://github.com/databricks/cli/pull/725)). + * Include in set of environment variables to pass along. ([#736](https://github.com/databricks/cli/pull/736)). + +Internal: + * Update Go SDK to v0.19.0 ([#729](https://github.com/databricks/cli/pull/729)). + * Replace API call to test configuration with dummy authenticate call ([#728](https://github.com/databricks/cli/pull/728)). + +API Changes: + * Changed `databricks account storage-credentials create` command to return . + * Changed `databricks account storage-credentials get` command to return . + * Changed `databricks account storage-credentials list` command to return . + * Changed `databricks account storage-credentials update` command to return . + * Changed `databricks connections create` command with new required argument order. + * Changed `databricks connections update` command with new required argument order. + * Changed `databricks volumes create` command with new required argument order. + * Added `databricks artifact-allowlists` command group. + * Added `databricks model-versions` command group. + * Added `databricks registered-models` command group. + * Added `databricks cluster-policies get-permission-levels` command. + * Added `databricks cluster-policies get-permissions` command. + * Added `databricks cluster-policies set-permissions` command. + * Added `databricks cluster-policies update-permissions` command. + * Added `databricks clusters get-permission-levels` command. + * Added `databricks clusters get-permissions` command. + * Added `databricks clusters set-permissions` command. + * Added `databricks clusters update-permissions` command. + * Added `databricks instance-pools get-permission-levels` command. + * Added `databricks instance-pools get-permissions` command. + * Added `databricks instance-pools set-permissions` command. + * Added `databricks instance-pools update-permissions` command. + * Added `databricks files` command group. + * Changed `databricks permissions set` command to start returning . + * Changed `databricks permissions update` command to start returning . + * Added `databricks users get-permission-levels` command. + * Added `databricks users get-permissions` command. + * Added `databricks users set-permissions` command. + * Added `databricks users update-permissions` command. + * Added `databricks jobs get-permission-levels` command. + * Added `databricks jobs get-permissions` command. + * Added `databricks jobs set-permissions` command. + * Added `databricks jobs update-permissions` command. + * Changed `databricks experiments get-by-name` command to return . + * Changed `databricks experiments get-experiment` command to return . + * Added `databricks experiments delete-runs` command. + * Added `databricks experiments get-permission-levels` command. + * Added `databricks experiments get-permissions` command. + * Added `databricks experiments restore-runs` command. + * Added `databricks experiments set-permissions` command. + * Added `databricks experiments update-permissions` command. + * Added `databricks model-registry get-permission-levels` command. + * Added `databricks model-registry get-permissions` command. + * Added `databricks model-registry set-permissions` command. + * Added `databricks model-registry update-permissions` command. + * Added `databricks pipelines get-permission-levels` command. + * Added `databricks pipelines get-permissions` command. + * Added `databricks pipelines set-permissions` command. + * Added `databricks pipelines update-permissions` command. + * Added `databricks serving-endpoints get-permission-levels` command. + * Added `databricks serving-endpoints get-permissions` command. + * Added `databricks serving-endpoints set-permissions` command. + * Added `databricks serving-endpoints update-permissions` command. + * Added `databricks token-management get-permission-levels` command. + * Added `databricks token-management get-permissions` command. + * Added `databricks token-management set-permissions` command. + * Added `databricks token-management update-permissions` command. + * Changed `databricks dashboards create` command with new required argument order. + * Added `databricks warehouses get-permission-levels` command. + * Added `databricks warehouses get-permissions` command. + * Added `databricks warehouses set-permissions` command. + * Added `databricks warehouses update-permissions` command. + * Added `databricks dashboard-widgets` command group. + * Added `databricks query-visualizations` command group. + * Added `databricks repos get-permission-levels` command. + * Added `databricks repos get-permissions` command. + * Added `databricks repos set-permissions` command. + * Added `databricks repos update-permissions` command. + * Added `databricks secrets get-secret` command. + * Added `databricks workspace get-permission-levels` command. + * Added `databricks workspace get-permissions` command. + * Added `databricks workspace set-permissions` command. + * Added `databricks workspace update-permissions` command. + +OpenAPI commit 09a7fa63d9ae243e5407941f200960ca14d48b07 (2023-09-04) + ## 0.203.3 Bundles: diff --git a/bundle/artifacts/artifacts_test.go b/bundle/artifacts/artifacts_test.go index 4c0a18f38..bbae44efa 100644 --- a/bundle/artifacts/artifacts_test.go +++ b/bundle/artifacts/artifacts_test.go @@ -105,6 +105,7 @@ func TestUploadArtifactFileToCorrectRemotePath(t *testing.T) { b.WorkspaceClient().Workspace.WithImpl(MockWorkspaceService{}) artifact := &config.Artifact{ + Type: "whl", Files: []config.ArtifactFile{ { Source: whlPath, @@ -118,4 +119,5 @@ func TestUploadArtifactFileToCorrectRemotePath(t *testing.T) { err := uploadArtifact(context.Background(), artifact, b) require.NoError(t, err) require.Regexp(t, regexp.MustCompile("/Users/test@databricks.com/whatever/.internal/[a-z0-9]+/test.whl"), artifact.Files[0].RemotePath) + require.Regexp(t, regexp.MustCompile("/Workspace/Users/test@databricks.com/whatever/.internal/[a-z0-9]+/test.whl"), artifact.Files[0].Libraries[0].Whl) } diff --git a/bundle/artifacts/whl/autodetect.go b/bundle/artifacts/whl/autodetect.go index 41d80bb76..29031e86d 100644 --- a/bundle/artifacts/whl/autodetect.go +++ b/bundle/artifacts/whl/autodetect.go @@ -27,9 +27,9 @@ func (m *detectPkg) Name() string { } func (m *detectPkg) Apply(ctx context.Context, b *bundle.Bundle) error { - wheelTasks := libraries.FindAllWheelTasks(b) + wheelTasks := libraries.FindAllWheelTasksWithLocalLibraries(b) if len(wheelTasks) == 0 { - log.Infof(ctx, "No wheel tasks in databricks.yml config, skipping auto detect") + log.Infof(ctx, "No local wheel tasks in databricks.yml config, skipping auto detect") return nil } cmdio.LogString(ctx, "artifacts.whl.AutoDetect: Detecting Python wheel project...") diff --git a/bundle/artifacts/whl/from_libraries.go b/bundle/artifacts/whl/from_libraries.go index 855e5b943..9d35f6314 100644 --- a/bundle/artifacts/whl/from_libraries.go +++ b/bundle/artifacts/whl/from_libraries.go @@ -26,7 +26,7 @@ func (*fromLibraries) Apply(ctx context.Context, b *bundle.Bundle) error { return nil } - tasks := libraries.FindAllWheelTasks(b) + tasks := libraries.FindAllWheelTasksWithLocalLibraries(b) for _, task := range tasks { for _, lib := range task.Libraries { matches, err := filepath.Glob(filepath.Join(b.Config.Path, lib.Whl)) diff --git a/bundle/bundle.go b/bundle/bundle.go index d69d58158..4fc605398 100644 --- a/bundle/bundle.go +++ b/bundle/bundle.go @@ -14,6 +14,7 @@ import ( "sync" "github.com/databricks/cli/bundle/config" + "github.com/databricks/cli/bundle/env" "github.com/databricks/cli/folders" "github.com/databricks/cli/libs/git" "github.com/databricks/cli/libs/locker" @@ -37,6 +38,10 @@ type Bundle struct { // Stores an initialized copy of this bundle's Terraform wrapper. Terraform *tfexec.Terraform + // Indicates that the Terraform definition based on this bundle is empty, + // i.e. that it would deploy no resources. + TerraformHasNoResources bool + // Stores the locker responsible for acquiring/releasing a deployment lock. Locker *locker.Locker @@ -47,8 +52,6 @@ type Bundle struct { AutoApprove bool } -const ExtraIncludePathsKey string = "DATABRICKS_BUNDLE_INCLUDES" - func Load(ctx context.Context, path string) (*Bundle, error) { bundle := &Bundle{} stat, err := os.Stat(path) @@ -57,9 +60,9 @@ func Load(ctx context.Context, path string) (*Bundle, error) { } configFile, err := config.FileNames.FindInPath(path) if err != nil { - _, hasIncludePathEnv := os.LookupEnv(ExtraIncludePathsKey) - _, hasBundleRootEnv := os.LookupEnv(envBundleRoot) - if hasIncludePathEnv && hasBundleRootEnv && stat.IsDir() { + _, hasRootEnv := env.Root(ctx) + _, hasIncludesEnv := env.Includes(ctx) + if hasRootEnv && hasIncludesEnv && stat.IsDir() { log.Debugf(ctx, "No bundle configuration; using bundle root: %s", path) bundle.Config = config.Root{ Path: path, @@ -82,7 +85,7 @@ func Load(ctx context.Context, path string) (*Bundle, error) { // MustLoad returns a bundle configuration. // It returns an error if a bundle was not found or could not be loaded. func MustLoad(ctx context.Context) (*Bundle, error) { - root, err := mustGetRoot() + root, err := mustGetRoot(ctx) if err != nil { return nil, err } @@ -94,7 +97,7 @@ func MustLoad(ctx context.Context) (*Bundle, error) { // It returns an error if a bundle was found but could not be loaded. // It returns a `nil` bundle if a bundle was not found. func TryLoad(ctx context.Context) (*Bundle, error) { - root, err := tryGetRoot() + root, err := tryGetRoot(ctx) if err != nil { return nil, err } @@ -120,13 +123,12 @@ func (b *Bundle) WorkspaceClient() *databricks.WorkspaceClient { // CacheDir returns directory to use for temporary files for this bundle. // Scoped to the bundle's target. -func (b *Bundle) CacheDir(paths ...string) (string, error) { +func (b *Bundle) CacheDir(ctx context.Context, paths ...string) (string, error) { if b.Config.Bundle.Target == "" { panic("target not set") } - cacheDirName, exists := os.LookupEnv("DATABRICKS_BUNDLE_TMP") - + cacheDirName, exists := env.TempDir(ctx) if !exists || cacheDirName == "" { cacheDirName = filepath.Join( // Anchor at bundle root directory. @@ -159,8 +161,8 @@ func (b *Bundle) CacheDir(paths ...string) (string, error) { // This directory is used to store and automaticaly sync internal bundle files, such as, f.e // notebook trampoline files for Python wheel and etc. -func (b *Bundle) InternalDir() (string, error) { - cacheDir, err := b.CacheDir() +func (b *Bundle) InternalDir(ctx context.Context) (string, error) { + cacheDir, err := b.CacheDir(ctx) if err != nil { return "", err } @@ -177,8 +179,8 @@ func (b *Bundle) InternalDir() (string, error) { // GetSyncIncludePatterns returns a list of user defined includes // And also adds InternalDir folder to include list for sync command // so this folder is always synced -func (b *Bundle) GetSyncIncludePatterns() ([]string, error) { - internalDir, err := b.InternalDir() +func (b *Bundle) GetSyncIncludePatterns(ctx context.Context) ([]string, error) { + internalDir, err := b.InternalDir(ctx) if err != nil { return nil, err } diff --git a/bundle/bundle_test.go b/bundle/bundle_test.go index 4a3e7f2c9..43477efd1 100644 --- a/bundle/bundle_test.go +++ b/bundle/bundle_test.go @@ -6,6 +6,7 @@ import ( "path/filepath" "testing" + "github.com/databricks/cli/bundle/env" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -23,12 +24,13 @@ func TestLoadExists(t *testing.T) { } func TestBundleCacheDir(t *testing.T) { + ctx := context.Background() projectDir := t.TempDir() f1, err := os.Create(filepath.Join(projectDir, "databricks.yml")) require.NoError(t, err) f1.Close() - bundle, err := Load(context.Background(), projectDir) + bundle, err := Load(ctx, projectDir) require.NoError(t, err) // Artificially set target. @@ -38,7 +40,7 @@ func TestBundleCacheDir(t *testing.T) { // unset env variable in case it's set t.Setenv("DATABRICKS_BUNDLE_TMP", "") - cacheDir, err := bundle.CacheDir() + cacheDir, err := bundle.CacheDir(ctx) // format is /.databricks/bundle/ assert.NoError(t, err) @@ -46,13 +48,14 @@ func TestBundleCacheDir(t *testing.T) { } func TestBundleCacheDirOverride(t *testing.T) { + ctx := context.Background() projectDir := t.TempDir() bundleTmpDir := t.TempDir() f1, err := os.Create(filepath.Join(projectDir, "databricks.yml")) require.NoError(t, err) f1.Close() - bundle, err := Load(context.Background(), projectDir) + bundle, err := Load(ctx, projectDir) require.NoError(t, err) // Artificially set target. @@ -62,7 +65,7 @@ func TestBundleCacheDirOverride(t *testing.T) { // now we expect to use 'bundleTmpDir' instead of CWD/.databricks/bundle t.Setenv("DATABRICKS_BUNDLE_TMP", bundleTmpDir) - cacheDir, err := bundle.CacheDir() + cacheDir, err := bundle.CacheDir(ctx) // format is / assert.NoError(t, err) @@ -70,14 +73,14 @@ func TestBundleCacheDirOverride(t *testing.T) { } func TestBundleMustLoadSuccess(t *testing.T) { - t.Setenv(envBundleRoot, "./tests/basic") + t.Setenv(env.RootVariable, "./tests/basic") b, err := MustLoad(context.Background()) require.NoError(t, err) assert.Equal(t, "tests/basic", filepath.ToSlash(b.Config.Path)) } func TestBundleMustLoadFailureWithEnv(t *testing.T) { - t.Setenv(envBundleRoot, "./tests/doesntexist") + t.Setenv(env.RootVariable, "./tests/doesntexist") _, err := MustLoad(context.Background()) require.Error(t, err, "not a directory") } @@ -89,14 +92,14 @@ func TestBundleMustLoadFailureIfNotFound(t *testing.T) { } func TestBundleTryLoadSuccess(t *testing.T) { - t.Setenv(envBundleRoot, "./tests/basic") + t.Setenv(env.RootVariable, "./tests/basic") b, err := TryLoad(context.Background()) require.NoError(t, err) assert.Equal(t, "tests/basic", filepath.ToSlash(b.Config.Path)) } func TestBundleTryLoadFailureWithEnv(t *testing.T) { - t.Setenv(envBundleRoot, "./tests/doesntexist") + t.Setenv(env.RootVariable, "./tests/doesntexist") _, err := TryLoad(context.Background()) require.Error(t, err, "not a directory") } diff --git a/bundle/config/artifact.go b/bundle/config/artifact.go index 1955e265d..d7048a02e 100644 --- a/bundle/config/artifact.go +++ b/bundle/config/artifact.go @@ -78,9 +78,13 @@ func (a *Artifact) NormalisePaths() { remotePath := path.Join(wsfsBase, f.RemotePath) for i := range f.Libraries { lib := f.Libraries[i] - switch a.Type { - case ArtifactPythonWheel: + if lib.Whl != "" { lib.Whl = remotePath + continue + } + if lib.Jar != "" { + lib.Jar = remotePath + continue } } diff --git a/bundle/config/interpolation/interpolation.go b/bundle/config/interpolation/interpolation.go index bf5bd169e..8ba0b8b1f 100644 --- a/bundle/config/interpolation/interpolation.go +++ b/bundle/config/interpolation/interpolation.go @@ -184,7 +184,7 @@ func (a *accumulator) Resolve(path string, seenPaths []string, fns ...LookupFunc // fetch the string node to resolve field, ok := a.strings[path] if !ok { - return fmt.Errorf("could not resolve reference %s", path) + return fmt.Errorf("no value found for interpolation reference: ${%s}", path) } // return early if the string field has no variables to interpolate diff --git a/bundle/config/interpolation/interpolation_test.go b/bundle/config/interpolation/interpolation_test.go index 83254c9b0..cccb6dc71 100644 --- a/bundle/config/interpolation/interpolation_test.go +++ b/bundle/config/interpolation/interpolation_test.go @@ -247,5 +247,5 @@ func TestInterpolationInvalidVariableReference(t *testing.T) { } err := expand(&config) - assert.ErrorContains(t, err, "could not resolve reference vars.foo") + assert.ErrorContains(t, err, "no value found for interpolation reference: ${vars.foo}") } diff --git a/bundle/config/mutator/override_compute.go b/bundle/config/mutator/override_compute.go index ee2e2a825..21d950135 100644 --- a/bundle/config/mutator/override_compute.go +++ b/bundle/config/mutator/override_compute.go @@ -3,11 +3,11 @@ package mutator import ( "context" "fmt" - "os" "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config" "github.com/databricks/cli/bundle/config/resources" + "github.com/databricks/cli/libs/env" ) type overrideCompute struct{} @@ -39,8 +39,8 @@ func (m *overrideCompute) Apply(ctx context.Context, b *bundle.Bundle) error { } return nil } - if os.Getenv("DATABRICKS_CLUSTER_ID") != "" { - b.Config.Bundle.ComputeID = os.Getenv("DATABRICKS_CLUSTER_ID") + if v := env.Get(ctx, "DATABRICKS_CLUSTER_ID"); v != "" { + b.Config.Bundle.ComputeID = v } if b.Config.Bundle.ComputeID == "" { diff --git a/bundle/config/mutator/populate_current_user.go b/bundle/config/mutator/populate_current_user.go index cbaa2d30b..bba0457c4 100644 --- a/bundle/config/mutator/populate_current_user.go +++ b/bundle/config/mutator/populate_current_user.go @@ -21,6 +21,10 @@ func (m *populateCurrentUser) Name() string { } func (m *populateCurrentUser) Apply(ctx context.Context, b *bundle.Bundle) error { + if b.Config.Workspace.CurrentUser != nil { + return nil + } + w := b.WorkspaceClient() me, err := w.CurrentUser.Me(ctx) if err != nil { diff --git a/bundle/config/mutator/process_root_includes.go b/bundle/config/mutator/process_root_includes.go index 989928721..5a5ab1b19 100644 --- a/bundle/config/mutator/process_root_includes.go +++ b/bundle/config/mutator/process_root_includes.go @@ -10,11 +10,12 @@ import ( "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config" + "github.com/databricks/cli/bundle/env" ) // Get extra include paths from environment variable -func GetExtraIncludePaths() []string { - value, exists := os.LookupEnv(bundle.ExtraIncludePathsKey) +func getExtraIncludePaths(ctx context.Context) []string { + value, exists := env.Includes(ctx) if !exists { return nil } @@ -48,7 +49,7 @@ func (m *processRootIncludes) Apply(ctx context.Context, b *bundle.Bundle) error var files []string // Converts extra include paths from environment variable to relative paths - for _, extraIncludePath := range GetExtraIncludePaths() { + for _, extraIncludePath := range getExtraIncludePaths(ctx) { if filepath.IsAbs(extraIncludePath) { rel, err := filepath.Rel(b.Config.Path, extraIncludePath) if err != nil { diff --git a/bundle/config/mutator/process_root_includes_test.go b/bundle/config/mutator/process_root_includes_test.go index 1ce094bc3..aec9b32df 100644 --- a/bundle/config/mutator/process_root_includes_test.go +++ b/bundle/config/mutator/process_root_includes_test.go @@ -2,16 +2,17 @@ package mutator_test import ( "context" - "fmt" "os" "path" "path/filepath" "runtime" + "strings" "testing" "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config" "github.com/databricks/cli/bundle/config/mutator" + "github.com/databricks/cli/bundle/env" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -129,10 +130,7 @@ func TestProcessRootIncludesExtrasFromEnvVar(t *testing.T) { rootPath := t.TempDir() testYamlName := "extra_include_path.yml" touch(t, rootPath, testYamlName) - os.Setenv(bundle.ExtraIncludePathsKey, path.Join(rootPath, testYamlName)) - t.Cleanup(func() { - os.Unsetenv(bundle.ExtraIncludePathsKey) - }) + t.Setenv(env.IncludesVariable, path.Join(rootPath, testYamlName)) bundle := &bundle.Bundle{ Config: config.Root{ @@ -149,7 +147,13 @@ func TestProcessRootIncludesDedupExtrasFromEnvVar(t *testing.T) { rootPath := t.TempDir() testYamlName := "extra_include_path.yml" touch(t, rootPath, testYamlName) - t.Setenv(bundle.ExtraIncludePathsKey, fmt.Sprintf("%s%s%s", path.Join(rootPath, testYamlName), string(os.PathListSeparator), path.Join(rootPath, testYamlName))) + t.Setenv(env.IncludesVariable, strings.Join( + []string{ + path.Join(rootPath, testYamlName), + path.Join(rootPath, testYamlName), + }, + string(os.PathListSeparator), + )) bundle := &bundle.Bundle{ Config: config.Root{ diff --git a/bundle/config/mutator/process_target_mode.go b/bundle/config/mutator/process_target_mode.go index 06ae7b858..93149ad04 100644 --- a/bundle/config/mutator/process_target_mode.go +++ b/bundle/config/mutator/process_target_mode.go @@ -77,6 +77,12 @@ func transformDevelopmentMode(b *bundle.Bundle) error { r.Experiments[i].Tags = append(r.Experiments[i].Tags, ml.ExperimentTag{Key: "dev", Value: b.Config.Workspace.CurrentUser.DisplayName}) } + for i := range r.ModelServingEndpoints { + prefix = "dev_" + b.Config.Workspace.CurrentUser.ShortName + "_" + r.ModelServingEndpoints[i].Name = prefix + r.ModelServingEndpoints[i].Name + // (model serving doesn't yet support tags) + } + return nil } diff --git a/bundle/config/mutator/process_target_mode_test.go b/bundle/config/mutator/process_target_mode_test.go index 489632e17..4ea33c70b 100644 --- a/bundle/config/mutator/process_target_mode_test.go +++ b/bundle/config/mutator/process_target_mode_test.go @@ -13,6 +13,7 @@ import ( "github.com/databricks/databricks-sdk-go/service/jobs" "github.com/databricks/databricks-sdk-go/service/ml" "github.com/databricks/databricks-sdk-go/service/pipelines" + "github.com/databricks/databricks-sdk-go/service/serving" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -53,6 +54,9 @@ func mockBundle(mode config.Mode) *bundle.Bundle { Models: map[string]*resources.MlflowModel{ "model1": {Model: &ml.Model{Name: "model1"}}, }, + ModelServingEndpoints: map[string]*resources.ModelServingEndpoint{ + "servingendpoint1": {CreateServingEndpoint: &serving.CreateServingEndpoint{Name: "servingendpoint1"}}, + }, }, }, } @@ -69,6 +73,7 @@ func TestProcessTargetModeDevelopment(t *testing.T) { assert.Equal(t, "/Users/lennart.kats@databricks.com/[dev lennart] experiment1", bundle.Config.Resources.Experiments["experiment1"].Name) assert.Equal(t, "[dev lennart] experiment2", bundle.Config.Resources.Experiments["experiment2"].Name) assert.Equal(t, "[dev lennart] model1", bundle.Config.Resources.Models["model1"].Name) + assert.Equal(t, "dev_lennart_servingendpoint1", bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name) assert.Equal(t, "dev", bundle.Config.Resources.Experiments["experiment1"].Experiment.Tags[0].Key) assert.True(t, bundle.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development) } @@ -82,6 +87,7 @@ func TestProcessTargetModeDefault(t *testing.T) { assert.Equal(t, "job1", bundle.Config.Resources.Jobs["job1"].Name) assert.Equal(t, "pipeline1", bundle.Config.Resources.Pipelines["pipeline1"].Name) assert.False(t, bundle.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development) + assert.Equal(t, "servingendpoint1", bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name) } func TestProcessTargetModeProduction(t *testing.T) { @@ -109,6 +115,7 @@ func TestProcessTargetModeProduction(t *testing.T) { bundle.Config.Resources.Experiments["experiment1"].Permissions = permissions bundle.Config.Resources.Experiments["experiment2"].Permissions = permissions bundle.Config.Resources.Models["model1"].Permissions = permissions + bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Permissions = permissions err = validateProductionMode(context.Background(), bundle, false) require.NoError(t, err) @@ -116,6 +123,7 @@ func TestProcessTargetModeProduction(t *testing.T) { assert.Equal(t, "job1", bundle.Config.Resources.Jobs["job1"].Name) assert.Equal(t, "pipeline1", bundle.Config.Resources.Pipelines["pipeline1"].Name) assert.False(t, bundle.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development) + assert.Equal(t, "servingendpoint1", bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name) } func TestProcessTargetModeProductionOkForPrincipal(t *testing.T) { diff --git a/bundle/config/mutator/select_target.go b/bundle/config/mutator/select_target.go index 3be1f2e1a..2ad431128 100644 --- a/bundle/config/mutator/select_target.go +++ b/bundle/config/mutator/select_target.go @@ -3,8 +3,10 @@ package mutator import ( "context" "fmt" + "strings" "github.com/databricks/cli/bundle" + "golang.org/x/exp/maps" ) type selectTarget struct { @@ -30,7 +32,7 @@ func (m *selectTarget) Apply(_ context.Context, b *bundle.Bundle) error { // Get specified target target, ok := b.Config.Targets[m.name] if !ok { - return fmt.Errorf("%s: no such target", m.name) + return fmt.Errorf("%s: no such target. Available targets: %s", m.name, strings.Join(maps.Keys(b.Config.Targets), ", ")) } // Merge specified target into root configuration structure. diff --git a/bundle/config/mutator/set_variables.go b/bundle/config/mutator/set_variables.go index 427b6dce4..4bf8ff82a 100644 --- a/bundle/config/mutator/set_variables.go +++ b/bundle/config/mutator/set_variables.go @@ -3,10 +3,10 @@ package mutator import ( "context" "fmt" - "os" "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config/variable" + "github.com/databricks/cli/libs/env" ) const bundleVarPrefix = "BUNDLE_VAR_" @@ -21,7 +21,7 @@ func (m *setVariables) Name() string { return "SetVariables" } -func setVariable(v *variable.Variable, name string) error { +func setVariable(ctx context.Context, v *variable.Variable, name string) error { // case: variable already has value initialized, so skip if v.HasValue() { return nil @@ -29,7 +29,7 @@ func setVariable(v *variable.Variable, name string) error { // case: read and set variable value from process environment envVarName := bundleVarPrefix + name - if val, ok := os.LookupEnv(envVarName); ok { + if val, ok := env.Lookup(ctx, envVarName); ok { err := v.Set(val) if err != nil { return fmt.Errorf(`failed to assign value "%s" to variable %s from environment variable %s with error: %w`, val, name, envVarName, err) @@ -54,7 +54,7 @@ func setVariable(v *variable.Variable, name string) error { func (m *setVariables) Apply(ctx context.Context, b *bundle.Bundle) error { for name, variable := range b.Config.Variables { - err := setVariable(variable, name) + err := setVariable(ctx, variable, name) if err != nil { return err } diff --git a/bundle/config/mutator/set_variables_test.go b/bundle/config/mutator/set_variables_test.go index 91948aa4b..323f1e864 100644 --- a/bundle/config/mutator/set_variables_test.go +++ b/bundle/config/mutator/set_variables_test.go @@ -21,7 +21,7 @@ func TestSetVariableFromProcessEnvVar(t *testing.T) { // set value for variable as an environment variable t.Setenv("BUNDLE_VAR_foo", "process-env") - err := setVariable(&variable, "foo") + err := setVariable(context.Background(), &variable, "foo") require.NoError(t, err) assert.Equal(t, *variable.Value, "process-env") } @@ -33,7 +33,7 @@ func TestSetVariableUsingDefaultValue(t *testing.T) { Default: &defaultVal, } - err := setVariable(&variable, "foo") + err := setVariable(context.Background(), &variable, "foo") require.NoError(t, err) assert.Equal(t, *variable.Value, "default") } @@ -49,7 +49,7 @@ func TestSetVariableWhenAlreadyAValueIsAssigned(t *testing.T) { // since a value is already assigned to the variable, it would not be overridden // by the default value - err := setVariable(&variable, "foo") + err := setVariable(context.Background(), &variable, "foo") require.NoError(t, err) assert.Equal(t, *variable.Value, "assigned-value") } @@ -68,7 +68,7 @@ func TestSetVariableEnvVarValueDoesNotOverridePresetValue(t *testing.T) { // since a value is already assigned to the variable, it would not be overridden // by the value from environment - err := setVariable(&variable, "foo") + err := setVariable(context.Background(), &variable, "foo") require.NoError(t, err) assert.Equal(t, *variable.Value, "assigned-value") } @@ -79,7 +79,7 @@ func TestSetVariablesErrorsIfAValueCouldNotBeResolved(t *testing.T) { } // fails because we could not resolve a value for the variable - err := setVariable(&variable, "foo") + err := setVariable(context.Background(), &variable, "foo") assert.ErrorContains(t, err, "no value assigned to required variable foo. Assignment can be done through the \"--var\" flag or by setting the BUNDLE_VAR_foo environment variable") } diff --git a/bundle/config/mutator/trampoline.go b/bundle/config/mutator/trampoline.go index ea938fdcf..809f41b27 100644 --- a/bundle/config/mutator/trampoline.go +++ b/bundle/config/mutator/trampoline.go @@ -38,7 +38,7 @@ func (m *trampoline) Name() string { func (m *trampoline) Apply(ctx context.Context, b *bundle.Bundle) error { tasks := m.functions.GetTasks(b) for _, task := range tasks { - err := m.generateNotebookWrapper(b, task) + err := m.generateNotebookWrapper(ctx, b, task) if err != nil { return err } @@ -46,8 +46,8 @@ func (m *trampoline) Apply(ctx context.Context, b *bundle.Bundle) error { return nil } -func (m *trampoline) generateNotebookWrapper(b *bundle.Bundle, task jobs_utils.TaskWithJobKey) error { - internalDir, err := b.InternalDir() +func (m *trampoline) generateNotebookWrapper(ctx context.Context, b *bundle.Bundle, task jobs_utils.TaskWithJobKey) error { + internalDir, err := b.InternalDir(ctx) if err != nil { return err } diff --git a/bundle/config/mutator/trampoline_test.go b/bundle/config/mutator/trampoline_test.go index df26f6166..de8f1b5e9 100644 --- a/bundle/config/mutator/trampoline_test.go +++ b/bundle/config/mutator/trampoline_test.go @@ -11,6 +11,7 @@ import ( "github.com/databricks/cli/bundle/config" "github.com/databricks/cli/bundle/config/paths" "github.com/databricks/cli/bundle/config/resources" + jobs_utils "github.com/databricks/cli/libs/jobs" "github.com/databricks/databricks-sdk-go/service/jobs" "github.com/stretchr/testify/require" ) @@ -18,10 +19,10 @@ import ( type functions struct { } -func (f *functions) GetTasks(b *bundle.Bundle) []TaskWithJobKey { - tasks := make([]TaskWithJobKey, 0) +func (f *functions) GetTasks(b *bundle.Bundle) []jobs_utils.TaskWithJobKey { + tasks := make([]jobs_utils.TaskWithJobKey, 0) for k := range b.Config.Resources.Jobs["test"].Tasks { - tasks = append(tasks, TaskWithJobKey{ + tasks = append(tasks, jobs_utils.TaskWithJobKey{ JobKey: "test", Task: &b.Config.Resources.Jobs["test"].Tasks[k], }) @@ -88,7 +89,7 @@ func TestGenerateTrampoline(t *testing.T) { err := bundle.Apply(ctx, b, trampoline) require.NoError(t, err) - dir, err := b.InternalDir() + dir, err := b.InternalDir(ctx) require.NoError(t, err) filename := filepath.Join(dir, "notebook_test_trampoline_test_to_trampoline.py") diff --git a/bundle/config/mutator/translate_paths_test.go b/bundle/config/mutator/translate_paths_test.go index e7ac5e8af..f7edee30a 100644 --- a/bundle/config/mutator/translate_paths_test.go +++ b/bundle/config/mutator/translate_paths_test.go @@ -162,7 +162,7 @@ func TestTranslatePaths(t *testing.T) { MainClassName: "HelloWorldRemote", }, Libraries: []compute.Library{ - {Jar: "dbfs:///bundle/dist/task_remote.jar"}, + {Jar: "dbfs:/bundle/dist/task_remote.jar"}, }, }, }, @@ -243,7 +243,7 @@ func TestTranslatePaths(t *testing.T) { ) assert.Equal( t, - "dbfs:///bundle/dist/task_remote.jar", + "dbfs:/bundle/dist/task_remote.jar", bundle.Config.Resources.Jobs["job"].Tasks[6].Libraries[0].Jar, ) diff --git a/bundle/config/resources.go b/bundle/config/resources.go index 5d47b918c..c239b510b 100644 --- a/bundle/config/resources.go +++ b/bundle/config/resources.go @@ -11,8 +11,9 @@ type Resources struct { Jobs map[string]*resources.Job `json:"jobs,omitempty"` Pipelines map[string]*resources.Pipeline `json:"pipelines,omitempty"` - Models map[string]*resources.MlflowModel `json:"models,omitempty"` - Experiments map[string]*resources.MlflowExperiment `json:"experiments,omitempty"` + Models map[string]*resources.MlflowModel `json:"models,omitempty"` + Experiments map[string]*resources.MlflowExperiment `json:"experiments,omitempty"` + ModelServingEndpoints map[string]*resources.ModelServingEndpoint `json:"model_serving_endpoints,omitempty"` } type UniqueResourceIdTracker struct { @@ -93,6 +94,19 @@ func (r *Resources) VerifyUniqueResourceIdentifiers() (*UniqueResourceIdTracker, tracker.Type[k] = "mlflow_experiment" tracker.ConfigPath[k] = r.Experiments[k].ConfigFilePath } + for k := range r.ModelServingEndpoints { + if _, ok := tracker.Type[k]; ok { + return tracker, fmt.Errorf("multiple resources named %s (%s at %s, %s at %s)", + k, + tracker.Type[k], + tracker.ConfigPath[k], + "model_serving_endpoint", + r.ModelServingEndpoints[k].ConfigFilePath, + ) + } + tracker.Type[k] = "model_serving_endpoint" + tracker.ConfigPath[k] = r.ModelServingEndpoints[k].ConfigFilePath + } return tracker, nil } @@ -112,6 +126,9 @@ func (r *Resources) SetConfigFilePath(path string) { for _, e := range r.Experiments { e.ConfigFilePath = path } + for _, e := range r.ModelServingEndpoints { + e.ConfigFilePath = path + } } // MergeJobClusters iterates over all jobs and merges their job clusters. diff --git a/bundle/config/resources/model_serving_endpoint.go b/bundle/config/resources/model_serving_endpoint.go new file mode 100644 index 000000000..dccecaa6f --- /dev/null +++ b/bundle/config/resources/model_serving_endpoint.go @@ -0,0 +1,24 @@ +package resources + +import ( + "github.com/databricks/cli/bundle/config/paths" + "github.com/databricks/databricks-sdk-go/service/serving" +) + +type ModelServingEndpoint struct { + // This represents the input args for terraform, and will get converted + // to a HCL representation for CRUD + *serving.CreateServingEndpoint + + // This represents the id (ie serving_endpoint_id) that can be used + // as a reference in other resources. This value is returned by terraform. + ID string + + // Local path where the bundle is defined. All bundle resources include + // this for interpolation purposes. + paths.Paths + + // This is a resource agnostic implementation of permissions for ACLs. + // Implementation could be different based on the resource type. + Permissions []Permission `json:"permissions,omitempty"` +} diff --git a/bundle/config/root.go b/bundle/config/root.go index 99ea33ad6..0377f60a0 100644 --- a/bundle/config/root.go +++ b/bundle/config/root.go @@ -52,7 +52,7 @@ type Root struct { // Bundle contains details about this bundle, such as its name, // version of the spec (TODO), default cluster, default warehouse, etc. - Bundle Bundle `json:"bundle"` + Bundle Bundle `json:"bundle,omitempty"` // Include specifies a list of patterns of file names to load and // merge into the this configuration. Only includes defined in the root @@ -80,7 +80,7 @@ type Root struct { Environments map[string]*Target `json:"environments,omitempty"` // Sync section specifies options for files synchronization - Sync Sync `json:"sync"` + Sync Sync `json:"sync,omitempty"` // RunAs section allows to define an execution identity for jobs and pipelines runs RunAs *jobs.JobRunAs `json:"run_as,omitempty"` diff --git a/bundle/deploy/files/sync.go b/bundle/deploy/files/sync.go index 2dccd20a7..ff3d78d07 100644 --- a/bundle/deploy/files/sync.go +++ b/bundle/deploy/files/sync.go @@ -9,12 +9,12 @@ import ( ) func getSync(ctx context.Context, b *bundle.Bundle) (*sync.Sync, error) { - cacheDir, err := b.CacheDir() + cacheDir, err := b.CacheDir(ctx) if err != nil { return nil, fmt.Errorf("cannot get bundle cache directory: %w", err) } - includes, err := b.GetSyncIncludePatterns() + includes, err := b.GetSyncIncludePatterns(ctx) if err != nil { return nil, fmt.Errorf("cannot get list of sync includes: %w", err) } diff --git a/bundle/deploy/terraform/apply.go b/bundle/deploy/terraform/apply.go index ab868f765..53cffbbaf 100644 --- a/bundle/deploy/terraform/apply.go +++ b/bundle/deploy/terraform/apply.go @@ -16,6 +16,10 @@ func (w *apply) Name() string { } func (w *apply) Apply(ctx context.Context, b *bundle.Bundle) error { + if b.TerraformHasNoResources { + cmdio.LogString(ctx, "Note: there are no resources to deploy for this bundle") + return nil + } tf := b.Terraform if tf == nil { return fmt.Errorf("terraform not initialized") diff --git a/bundle/deploy/terraform/convert.go b/bundle/deploy/terraform/convert.go index ac68bd359..0956ea7bb 100644 --- a/bundle/deploy/terraform/convert.go +++ b/bundle/deploy/terraform/convert.go @@ -49,12 +49,14 @@ func convPermission(ac resources.Permission) schema.ResourcePermissionsAccessCon // // NOTE: THIS IS CURRENTLY A HACK. WE NEED A BETTER WAY TO // CONVERT TO/FROM TERRAFORM COMPATIBLE FORMAT. -func BundleToTerraform(config *config.Root) *schema.Root { +func BundleToTerraform(config *config.Root) (*schema.Root, bool) { tfroot := schema.NewRoot() tfroot.Provider = schema.NewProviders() tfroot.Resource = schema.NewResources() + noResources := true for k, src := range config.Resources.Jobs { + noResources = false var dst schema.ResourceJob conv(src, &dst) @@ -88,6 +90,12 @@ func BundleToTerraform(config *config.Root) *schema.Root { Tag: git.GitTag, } } + + for _, v := range src.Parameters { + var t schema.ResourceJobParameter + conv(v, &t) + dst.Parameter = append(dst.Parameter, t) + } } tfroot.Resource.Job[k] = &dst @@ -100,6 +108,7 @@ func BundleToTerraform(config *config.Root) *schema.Root { } for k, src := range config.Resources.Pipelines { + noResources = false var dst schema.ResourcePipeline conv(src, &dst) @@ -127,6 +136,7 @@ func BundleToTerraform(config *config.Root) *schema.Root { } for k, src := range config.Resources.Models { + noResources = false var dst schema.ResourceMlflowModel conv(src, &dst) tfroot.Resource.MlflowModel[k] = &dst @@ -139,6 +149,7 @@ func BundleToTerraform(config *config.Root) *schema.Root { } for k, src := range config.Resources.Experiments { + noResources = false var dst schema.ResourceMlflowExperiment conv(src, &dst) tfroot.Resource.MlflowExperiment[k] = &dst @@ -150,7 +161,20 @@ func BundleToTerraform(config *config.Root) *schema.Root { } } - return tfroot + for k, src := range config.Resources.ModelServingEndpoints { + noResources = false + var dst schema.ResourceModelServing + conv(src, &dst) + tfroot.Resource.ModelServing[k] = &dst + + // Configure permissions for this resource. + if rp := convPermissions(src.Permissions); rp != nil { + rp.ServingEndpointId = fmt.Sprintf("${databricks_model_serving.%s.serving_endpoint_id}", k) + tfroot.Resource.Permissions["model_serving_"+k] = rp + } + } + + return tfroot, noResources } func TerraformToBundle(state *tfjson.State, config *config.Root) error { @@ -185,6 +209,12 @@ func TerraformToBundle(state *tfjson.State, config *config.Root) error { cur := config.Resources.Experiments[resource.Name] conv(tmp, &cur) config.Resources.Experiments[resource.Name] = cur + case "databricks_model_serving": + var tmp schema.ResourceModelServing + conv(resource.AttributeValues, &tmp) + cur := config.Resources.ModelServingEndpoints[resource.Name] + conv(tmp, &cur) + config.Resources.ModelServingEndpoints[resource.Name] = cur case "databricks_permissions": // Ignore; no need to pull these back into the configuration. default: diff --git a/bundle/deploy/terraform/convert_test.go b/bundle/deploy/terraform/convert_test.go index c47824ec5..ad6266066 100644 --- a/bundle/deploy/terraform/convert_test.go +++ b/bundle/deploy/terraform/convert_test.go @@ -9,6 +9,7 @@ import ( "github.com/databricks/databricks-sdk-go/service/jobs" "github.com/databricks/databricks-sdk-go/service/ml" "github.com/databricks/databricks-sdk-go/service/pipelines" + "github.com/databricks/databricks-sdk-go/service/serving" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -29,6 +30,16 @@ func TestConvertJob(t *testing.T) { GitProvider: jobs.GitProviderGitHub, GitUrl: "https://github.com/foo/bar", }, + Parameters: []jobs.JobParameterDefinition{ + { + Name: "param1", + Default: "default1", + }, + { + Name: "param2", + Default: "default2", + }, + }, }, } @@ -40,10 +51,13 @@ func TestConvertJob(t *testing.T) { }, } - out := BundleToTerraform(&config) + out, _ := BundleToTerraform(&config) assert.Equal(t, "my job", out.Resource.Job["my_job"].Name) assert.Len(t, out.Resource.Job["my_job"].JobCluster, 1) assert.Equal(t, "https://github.com/foo/bar", out.Resource.Job["my_job"].GitSource.Url) + assert.Len(t, out.Resource.Job["my_job"].Parameter, 2) + assert.Equal(t, "param1", out.Resource.Job["my_job"].Parameter[0].Name) + assert.Equal(t, "param2", out.Resource.Job["my_job"].Parameter[1].Name) assert.Nil(t, out.Data) } @@ -65,7 +79,7 @@ func TestConvertJobPermissions(t *testing.T) { }, } - out := BundleToTerraform(&config) + out, _ := BundleToTerraform(&config) assert.NotEmpty(t, out.Resource.Permissions["job_my_job"].JobId) assert.Len(t, out.Resource.Permissions["job_my_job"].AccessControl, 1) @@ -101,7 +115,7 @@ func TestConvertJobTaskLibraries(t *testing.T) { }, } - out := BundleToTerraform(&config) + out, _ := BundleToTerraform(&config) assert.Equal(t, "my job", out.Resource.Job["my_job"].Name) require.Len(t, out.Resource.Job["my_job"].Task, 1) require.Len(t, out.Resource.Job["my_job"].Task[0].Library, 1) @@ -135,7 +149,7 @@ func TestConvertPipeline(t *testing.T) { }, } - out := BundleToTerraform(&config) + out, _ := BundleToTerraform(&config) assert.Equal(t, "my pipeline", out.Resource.Pipeline["my_pipeline"].Name) assert.Len(t, out.Resource.Pipeline["my_pipeline"].Library, 2) assert.Nil(t, out.Data) @@ -159,7 +173,7 @@ func TestConvertPipelinePermissions(t *testing.T) { }, } - out := BundleToTerraform(&config) + out, _ := BundleToTerraform(&config) assert.NotEmpty(t, out.Resource.Permissions["pipeline_my_pipeline"].PipelineId) assert.Len(t, out.Resource.Permissions["pipeline_my_pipeline"].AccessControl, 1) @@ -194,7 +208,7 @@ func TestConvertModel(t *testing.T) { }, } - out := BundleToTerraform(&config) + out, _ := BundleToTerraform(&config) assert.Equal(t, "name", out.Resource.MlflowModel["my_model"].Name) assert.Equal(t, "description", out.Resource.MlflowModel["my_model"].Description) assert.Len(t, out.Resource.MlflowModel["my_model"].Tags, 2) @@ -223,7 +237,7 @@ func TestConvertModelPermissions(t *testing.T) { }, } - out := BundleToTerraform(&config) + out, _ := BundleToTerraform(&config) assert.NotEmpty(t, out.Resource.Permissions["mlflow_model_my_model"].RegisteredModelId) assert.Len(t, out.Resource.Permissions["mlflow_model_my_model"].AccessControl, 1) @@ -247,7 +261,7 @@ func TestConvertExperiment(t *testing.T) { }, } - out := BundleToTerraform(&config) + out, _ := BundleToTerraform(&config) assert.Equal(t, "name", out.Resource.MlflowExperiment["my_experiment"].Name) assert.Nil(t, out.Data) } @@ -270,7 +284,7 @@ func TestConvertExperimentPermissions(t *testing.T) { }, } - out := BundleToTerraform(&config) + out, _ := BundleToTerraform(&config) assert.NotEmpty(t, out.Resource.Permissions["mlflow_experiment_my_experiment"].ExperimentId) assert.Len(t, out.Resource.Permissions["mlflow_experiment_my_experiment"].AccessControl, 1) @@ -279,3 +293,76 @@ func TestConvertExperimentPermissions(t *testing.T) { assert.Equal(t, "CAN_READ", p.PermissionLevel) } + +func TestConvertModelServing(t *testing.T) { + var src = resources.ModelServingEndpoint{ + CreateServingEndpoint: &serving.CreateServingEndpoint{ + Name: "name", + Config: serving.EndpointCoreConfigInput{ + ServedModels: []serving.ServedModelInput{ + { + ModelName: "model_name", + ModelVersion: "1", + ScaleToZeroEnabled: true, + WorkloadSize: "Small", + }, + }, + TrafficConfig: &serving.TrafficConfig{ + Routes: []serving.Route{ + { + ServedModelName: "model_name-1", + TrafficPercentage: 100, + }, + }, + }, + }, + }, + } + + var config = config.Root{ + Resources: config.Resources{ + ModelServingEndpoints: map[string]*resources.ModelServingEndpoint{ + "my_model_serving_endpoint": &src, + }, + }, + } + + out, _ := BundleToTerraform(&config) + resource := out.Resource.ModelServing["my_model_serving_endpoint"] + assert.Equal(t, "name", resource.Name) + assert.Equal(t, "model_name", resource.Config.ServedModels[0].ModelName) + assert.Equal(t, "1", resource.Config.ServedModels[0].ModelVersion) + assert.Equal(t, true, resource.Config.ServedModels[0].ScaleToZeroEnabled) + assert.Equal(t, "Small", resource.Config.ServedModels[0].WorkloadSize) + assert.Equal(t, "model_name-1", resource.Config.TrafficConfig.Routes[0].ServedModelName) + assert.Equal(t, 100, resource.Config.TrafficConfig.Routes[0].TrafficPercentage) + assert.Nil(t, out.Data) +} + +func TestConvertModelServingPermissions(t *testing.T) { + var src = resources.ModelServingEndpoint{ + Permissions: []resources.Permission{ + { + Level: "CAN_VIEW", + UserName: "jane@doe.com", + }, + }, + } + + var config = config.Root{ + Resources: config.Resources{ + ModelServingEndpoints: map[string]*resources.ModelServingEndpoint{ + "my_model_serving_endpoint": &src, + }, + }, + } + + out, _ := BundleToTerraform(&config) + assert.NotEmpty(t, out.Resource.Permissions["model_serving_my_model_serving_endpoint"].ServingEndpointId) + assert.Len(t, out.Resource.Permissions["model_serving_my_model_serving_endpoint"].AccessControl, 1) + + p := out.Resource.Permissions["model_serving_my_model_serving_endpoint"].AccessControl[0] + assert.Equal(t, "jane@doe.com", p.UserName) + assert.Equal(t, "CAN_VIEW", p.PermissionLevel) + +} diff --git a/bundle/deploy/terraform/dir.go b/bundle/deploy/terraform/dir.go index 9f83b8da5..b7b086ceb 100644 --- a/bundle/deploy/terraform/dir.go +++ b/bundle/deploy/terraform/dir.go @@ -1,11 +1,13 @@ package terraform import ( + "context" + "github.com/databricks/cli/bundle" ) // Dir returns the Terraform working directory for a given bundle. // The working directory is emphemeral and nested under the bundle's cache directory. -func Dir(b *bundle.Bundle) (string, error) { - return b.CacheDir("terraform") +func Dir(ctx context.Context, b *bundle.Bundle) (string, error) { + return b.CacheDir(ctx, "terraform") } diff --git a/bundle/deploy/terraform/init.go b/bundle/deploy/terraform/init.go index 6df7b8d48..aa1dff74e 100644 --- a/bundle/deploy/terraform/init.go +++ b/bundle/deploy/terraform/init.go @@ -8,9 +8,11 @@ import ( "path/filepath" "runtime" "strings" + "time" "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config" + "github.com/databricks/cli/libs/env" "github.com/databricks/cli/libs/log" "github.com/hashicorp/go-version" "github.com/hashicorp/hc-install/product" @@ -37,7 +39,7 @@ func (m *initialize) findExecPath(ctx context.Context, b *bundle.Bundle, tf *con return tf.ExecPath, nil } - binDir, err := b.CacheDir("bin") + binDir, err := b.CacheDir(context.Background(), "bin") if err != nil { return "", err } @@ -59,6 +61,7 @@ func (m *initialize) findExecPath(ctx context.Context, b *bundle.Bundle, tf *con Product: product.Terraform, Version: version.Must(version.NewVersion("1.5.5")), InstallDir: binDir, + Timeout: 1 * time.Minute, } execPath, err = installer.Install(ctx) if err != nil { @@ -71,17 +74,25 @@ func (m *initialize) findExecPath(ctx context.Context, b *bundle.Bundle, tf *con } // This function inherits some environment variables for Terraform CLI. -func inheritEnvVars(env map[string]string) error { +func inheritEnvVars(ctx context.Context, environ map[string]string) error { // Include $HOME in set of environment variables to pass along. - home, ok := os.LookupEnv("HOME") + home, ok := env.Lookup(ctx, "HOME") if ok { - env["HOME"] = home + environ["HOME"] = home + } + + // Include $PATH in set of environment variables to pass along. + // This is necessary to ensure that our Terraform provider can use the + // same auxiliary programs (e.g. `az`, or `gcloud`) as the CLI. + path, ok := env.Lookup(ctx, "PATH") + if ok { + environ["PATH"] = path } // Include $TF_CLI_CONFIG_FILE to override terraform provider in development. - configFile, ok := os.LookupEnv("TF_CLI_CONFIG_FILE") + configFile, ok := env.Lookup(ctx, "TF_CLI_CONFIG_FILE") if ok { - env["TF_CLI_CONFIG_FILE"] = configFile + environ["TF_CLI_CONFIG_FILE"] = configFile } return nil @@ -95,40 +106,40 @@ func inheritEnvVars(env map[string]string) error { // the CLI and its dependencies do not have access to. // // see: os.TempDir for more context -func setTempDirEnvVars(env map[string]string, b *bundle.Bundle) error { +func setTempDirEnvVars(ctx context.Context, environ map[string]string, b *bundle.Bundle) error { switch runtime.GOOS { case "windows": - if v, ok := os.LookupEnv("TMP"); ok { - env["TMP"] = v - } else if v, ok := os.LookupEnv("TEMP"); ok { - env["TEMP"] = v - } else if v, ok := os.LookupEnv("USERPROFILE"); ok { - env["USERPROFILE"] = v + if v, ok := env.Lookup(ctx, "TMP"); ok { + environ["TMP"] = v + } else if v, ok := env.Lookup(ctx, "TEMP"); ok { + environ["TEMP"] = v + } else if v, ok := env.Lookup(ctx, "USERPROFILE"); ok { + environ["USERPROFILE"] = v } else { - tmpDir, err := b.CacheDir("tmp") + tmpDir, err := b.CacheDir(ctx, "tmp") if err != nil { return err } - env["TMP"] = tmpDir + environ["TMP"] = tmpDir } default: // If TMPDIR is not set, we let the process fall back to its default value. - if v, ok := os.LookupEnv("TMPDIR"); ok { - env["TMPDIR"] = v + if v, ok := env.Lookup(ctx, "TMPDIR"); ok { + environ["TMPDIR"] = v } } return nil } // This function passes through all proxy related environment variables. -func setProxyEnvVars(env map[string]string, b *bundle.Bundle) error { +func setProxyEnvVars(ctx context.Context, environ map[string]string, b *bundle.Bundle) error { for _, v := range []string{"http_proxy", "https_proxy", "no_proxy"} { // The case (upper or lower) is notoriously inconsistent for tools on Unix systems. // We therefore try to read both the upper and lower case versions of the variable. for _, v := range []string{strings.ToUpper(v), strings.ToLower(v)} { - if val, ok := os.LookupEnv(v); ok { + if val, ok := env.Lookup(ctx, v); ok { // Only set uppercase version of the variable. - env[strings.ToUpper(v)] = val + environ[strings.ToUpper(v)] = val } } } @@ -147,7 +158,7 @@ func (m *initialize) Apply(ctx context.Context, b *bundle.Bundle) error { return err } - workingDir, err := Dir(b) + workingDir, err := Dir(ctx, b) if err != nil { return err } @@ -157,31 +168,31 @@ func (m *initialize) Apply(ctx context.Context, b *bundle.Bundle) error { return err } - env, err := b.AuthEnv() + environ, err := b.AuthEnv() if err != nil { return err } - err = inheritEnvVars(env) + err = inheritEnvVars(ctx, environ) if err != nil { return err } // Set the temporary directory environment variables - err = setTempDirEnvVars(env, b) + err = setTempDirEnvVars(ctx, environ, b) if err != nil { return err } // Set the proxy related environment variables - err = setProxyEnvVars(env, b) + err = setProxyEnvVars(ctx, environ, b) if err != nil { return err } // Configure environment variables for auth for Terraform to use. - log.Debugf(ctx, "Environment variables for Terraform: %s", strings.Join(maps.Keys(env), ", ")) - err = tf.SetEnv(env) + log.Debugf(ctx, "Environment variables for Terraform: %s", strings.Join(maps.Keys(environ), ", ")) + err = tf.SetEnv(environ) if err != nil { return err } diff --git a/bundle/deploy/terraform/init_test.go b/bundle/deploy/terraform/init_test.go index 5bb5929e6..001e7a220 100644 --- a/bundle/deploy/terraform/init_test.go +++ b/bundle/deploy/terraform/init_test.go @@ -68,7 +68,7 @@ func TestSetTempDirEnvVarsForUnixWithTmpDirSet(t *testing.T) { // compute env env := make(map[string]string, 0) - err := setTempDirEnvVars(env, b) + err := setTempDirEnvVars(context.Background(), env, b) require.NoError(t, err) // Assert that we pass through TMPDIR. @@ -96,7 +96,7 @@ func TestSetTempDirEnvVarsForUnixWithTmpDirNotSet(t *testing.T) { // compute env env := make(map[string]string, 0) - err := setTempDirEnvVars(env, b) + err := setTempDirEnvVars(context.Background(), env, b) require.NoError(t, err) // Assert that we don't pass through TMPDIR. @@ -124,7 +124,7 @@ func TestSetTempDirEnvVarsForWindowWithAllTmpDirEnvVarsSet(t *testing.T) { // compute env env := make(map[string]string, 0) - err := setTempDirEnvVars(env, b) + err := setTempDirEnvVars(context.Background(), env, b) require.NoError(t, err) // assert that we pass through the highest priority env var value @@ -154,7 +154,7 @@ func TestSetTempDirEnvVarsForWindowWithUserProfileAndTempSet(t *testing.T) { // compute env env := make(map[string]string, 0) - err := setTempDirEnvVars(env, b) + err := setTempDirEnvVars(context.Background(), env, b) require.NoError(t, err) // assert that we pass through the highest priority env var value @@ -184,7 +184,7 @@ func TestSetTempDirEnvVarsForWindowWithUserProfileSet(t *testing.T) { // compute env env := make(map[string]string, 0) - err := setTempDirEnvVars(env, b) + err := setTempDirEnvVars(context.Background(), env, b) require.NoError(t, err) // assert that we pass through the user profile @@ -214,11 +214,11 @@ func TestSetTempDirEnvVarsForWindowsWithoutAnyTempDirEnvVarsSet(t *testing.T) { // compute env env := make(map[string]string, 0) - err := setTempDirEnvVars(env, b) + err := setTempDirEnvVars(context.Background(), env, b) require.NoError(t, err) // assert TMP is set to b.CacheDir("tmp") - tmpDir, err := b.CacheDir("tmp") + tmpDir, err := b.CacheDir(context.Background(), "tmp") require.NoError(t, err) assert.Equal(t, map[string]string{ "TMP": tmpDir, @@ -248,7 +248,7 @@ func TestSetProxyEnvVars(t *testing.T) { // No proxy env vars set. clearEnv() env := make(map[string]string, 0) - err := setProxyEnvVars(env, b) + err := setProxyEnvVars(context.Background(), env, b) require.NoError(t, err) assert.Len(t, env, 0) @@ -258,7 +258,7 @@ func TestSetProxyEnvVars(t *testing.T) { t.Setenv("https_proxy", "foo") t.Setenv("no_proxy", "foo") env = make(map[string]string, 0) - err = setProxyEnvVars(env, b) + err = setProxyEnvVars(context.Background(), env, b) require.NoError(t, err) assert.ElementsMatch(t, []string{"HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY"}, maps.Keys(env)) @@ -268,7 +268,7 @@ func TestSetProxyEnvVars(t *testing.T) { t.Setenv("HTTPS_PROXY", "foo") t.Setenv("NO_PROXY", "foo") env = make(map[string]string, 0) - err = setProxyEnvVars(env, b) + err = setProxyEnvVars(context.Background(), env, b) require.NoError(t, err) assert.ElementsMatch(t, []string{"HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY"}, maps.Keys(env)) } @@ -277,14 +277,16 @@ func TestInheritEnvVars(t *testing.T) { env := map[string]string{} t.Setenv("HOME", "/home/testuser") + t.Setenv("PATH", "/foo:/bar") t.Setenv("TF_CLI_CONFIG_FILE", "/tmp/config.tfrc") - err := inheritEnvVars(env) + err := inheritEnvVars(context.Background(), env) require.NoError(t, err) require.Equal(t, map[string]string{ "HOME": "/home/testuser", + "PATH": "/foo:/bar", "TF_CLI_CONFIG_FILE": "/tmp/config.tfrc", }, env) } diff --git a/bundle/deploy/terraform/interpolate.go b/bundle/deploy/terraform/interpolate.go index dd1dcbb88..ea3c99aa1 100644 --- a/bundle/deploy/terraform/interpolate.go +++ b/bundle/deploy/terraform/interpolate.go @@ -25,6 +25,9 @@ func interpolateTerraformResourceIdentifiers(path string, lookup map[string]stri case "experiments": path = strings.Join(append([]string{"databricks_mlflow_experiment"}, parts[2:]...), interpolation.Delimiter) return fmt.Sprintf("${%s}", path), nil + case "model_serving_endpoints": + path = strings.Join(append([]string{"databricks_model_serving"}, parts[2:]...), interpolation.Delimiter) + return fmt.Sprintf("${%s}", path), nil default: panic("TODO: " + parts[1]) } diff --git a/bundle/deploy/terraform/plan.go b/bundle/deploy/terraform/plan.go index a725b4aa9..ff841148c 100644 --- a/bundle/deploy/terraform/plan.go +++ b/bundle/deploy/terraform/plan.go @@ -40,7 +40,7 @@ func (p *plan) Apply(ctx context.Context, b *bundle.Bundle) error { } // Persist computed plan - tfDir, err := Dir(b) + tfDir, err := Dir(ctx, b) if err != nil { return err } diff --git a/bundle/deploy/terraform/state_pull.go b/bundle/deploy/terraform/state_pull.go index e5a42d89b..6dd12ccfc 100644 --- a/bundle/deploy/terraform/state_pull.go +++ b/bundle/deploy/terraform/state_pull.go @@ -25,7 +25,7 @@ func (l *statePull) Apply(ctx context.Context, b *bundle.Bundle) error { return err } - dir, err := Dir(b) + dir, err := Dir(ctx, b) if err != nil { return err } diff --git a/bundle/deploy/terraform/state_push.go b/bundle/deploy/terraform/state_push.go index 0b4c5dbfa..ae1d8b8b3 100644 --- a/bundle/deploy/terraform/state_push.go +++ b/bundle/deploy/terraform/state_push.go @@ -22,7 +22,7 @@ func (l *statePush) Apply(ctx context.Context, b *bundle.Bundle) error { return err } - dir, err := Dir(b) + dir, err := Dir(ctx, b) if err != nil { return err } @@ -32,6 +32,7 @@ func (l *statePush) Apply(ctx context.Context, b *bundle.Bundle) error { if err != nil { return err } + defer local.Close() // Upload state file from local cache directory to filer. log.Infof(ctx, "Writing local state file to remote state directory") diff --git a/bundle/deploy/terraform/write.go b/bundle/deploy/terraform/write.go index b40a70531..eca79ad21 100644 --- a/bundle/deploy/terraform/write.go +++ b/bundle/deploy/terraform/write.go @@ -16,12 +16,13 @@ func (w *write) Name() string { } func (w *write) Apply(ctx context.Context, b *bundle.Bundle) error { - dir, err := Dir(b) + dir, err := Dir(ctx, b) if err != nil { return err } - root := BundleToTerraform(&b.Config) + root, noResources := BundleToTerraform(&b.Config) + b.TerraformHasNoResources = noResources f, err := os.Create(filepath.Join(dir, "bundle.tf.json")) if err != nil { return err diff --git a/bundle/env/env.go b/bundle/env/env.go new file mode 100644 index 000000000..ed2a13c75 --- /dev/null +++ b/bundle/env/env.go @@ -0,0 +1,18 @@ +package env + +import ( + "context" + + envlib "github.com/databricks/cli/libs/env" +) + +// Return the value of the first environment variable that is set. +func get(ctx context.Context, variables []string) (string, bool) { + for _, v := range variables { + value, ok := envlib.Lookup(ctx, v) + if ok { + return value, true + } + } + return "", false +} diff --git a/bundle/env/env_test.go b/bundle/env/env_test.go new file mode 100644 index 000000000..d900242e0 --- /dev/null +++ b/bundle/env/env_test.go @@ -0,0 +1,44 @@ +package env + +import ( + "context" + "testing" + + "github.com/databricks/cli/internal/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetWithRealEnvSingleVariable(t *testing.T) { + testutil.CleanupEnvironment(t) + t.Setenv("v1", "foo") + + v, ok := get(context.Background(), []string{"v1"}) + require.True(t, ok) + assert.Equal(t, "foo", v) + + // Not set. + v, ok = get(context.Background(), []string{"v2"}) + require.False(t, ok) + assert.Equal(t, "", v) +} + +func TestGetWithRealEnvMultipleVariables(t *testing.T) { + testutil.CleanupEnvironment(t) + t.Setenv("v1", "foo") + + for _, vars := range [][]string{ + {"v1", "v2", "v3"}, + {"v2", "v3", "v1"}, + {"v3", "v1", "v2"}, + } { + v, ok := get(context.Background(), vars) + require.True(t, ok) + assert.Equal(t, "foo", v) + } + + // Not set. + v, ok := get(context.Background(), []string{"v2", "v3", "v4"}) + require.False(t, ok) + assert.Equal(t, "", v) +} diff --git a/bundle/env/includes.go b/bundle/env/includes.go new file mode 100644 index 000000000..4ade01877 --- /dev/null +++ b/bundle/env/includes.go @@ -0,0 +1,14 @@ +package env + +import "context" + +// IncludesVariable names the environment variable that holds additional configuration paths to include +// during bundle configuration loading. Also see `bundle/config/mutator/process_root_includes.go`. +const IncludesVariable = "DATABRICKS_BUNDLE_INCLUDES" + +// Includes returns the bundle Includes environment variable. +func Includes(ctx context.Context) (string, bool) { + return get(ctx, []string{ + IncludesVariable, + }) +} diff --git a/bundle/env/includes_test.go b/bundle/env/includes_test.go new file mode 100644 index 000000000..d9366a59f --- /dev/null +++ b/bundle/env/includes_test.go @@ -0,0 +1,28 @@ +package env + +import ( + "context" + "testing" + + "github.com/databricks/cli/internal/testutil" + "github.com/stretchr/testify/assert" +) + +func TestIncludes(t *testing.T) { + ctx := context.Background() + + testutil.CleanupEnvironment(t) + + t.Run("set", func(t *testing.T) { + t.Setenv("DATABRICKS_BUNDLE_INCLUDES", "foo") + includes, ok := Includes(ctx) + assert.True(t, ok) + assert.Equal(t, "foo", includes) + }) + + t.Run("not set", func(t *testing.T) { + includes, ok := Includes(ctx) + assert.False(t, ok) + assert.Equal(t, "", includes) + }) +} diff --git a/bundle/env/root.go b/bundle/env/root.go new file mode 100644 index 000000000..e3c2a38ad --- /dev/null +++ b/bundle/env/root.go @@ -0,0 +1,16 @@ +package env + +import "context" + +// RootVariable names the environment variable that holds the bundle root path. +const RootVariable = "DATABRICKS_BUNDLE_ROOT" + +// Root returns the bundle root environment variable. +func Root(ctx context.Context) (string, bool) { + return get(ctx, []string{ + RootVariable, + + // Primary variable name for the bundle root until v0.204.0. + "BUNDLE_ROOT", + }) +} diff --git a/bundle/env/root_test.go b/bundle/env/root_test.go new file mode 100644 index 000000000..fc2d6e206 --- /dev/null +++ b/bundle/env/root_test.go @@ -0,0 +1,43 @@ +package env + +import ( + "context" + "testing" + + "github.com/databricks/cli/internal/testutil" + "github.com/stretchr/testify/assert" +) + +func TestRoot(t *testing.T) { + ctx := context.Background() + + testutil.CleanupEnvironment(t) + + t.Run("first", func(t *testing.T) { + t.Setenv("DATABRICKS_BUNDLE_ROOT", "foo") + root, ok := Root(ctx) + assert.True(t, ok) + assert.Equal(t, "foo", root) + }) + + t.Run("second", func(t *testing.T) { + t.Setenv("BUNDLE_ROOT", "foo") + root, ok := Root(ctx) + assert.True(t, ok) + assert.Equal(t, "foo", root) + }) + + t.Run("both set", func(t *testing.T) { + t.Setenv("DATABRICKS_BUNDLE_ROOT", "first") + t.Setenv("BUNDLE_ROOT", "second") + root, ok := Root(ctx) + assert.True(t, ok) + assert.Equal(t, "first", root) + }) + + t.Run("not set", func(t *testing.T) { + root, ok := Root(ctx) + assert.False(t, ok) + assert.Equal(t, "", root) + }) +} diff --git a/bundle/env/target.go b/bundle/env/target.go new file mode 100644 index 000000000..ac3b48877 --- /dev/null +++ b/bundle/env/target.go @@ -0,0 +1,17 @@ +package env + +import "context" + +// TargetVariable names the environment variable that holds the bundle target to use. +const TargetVariable = "DATABRICKS_BUNDLE_TARGET" + +// Target returns the bundle target environment variable. +func Target(ctx context.Context) (string, bool) { + return get(ctx, []string{ + TargetVariable, + + // Primary variable name for the bundle target until v0.203.2. + // See https://github.com/databricks/cli/pull/670. + "DATABRICKS_BUNDLE_ENV", + }) +} diff --git a/bundle/env/target_test.go b/bundle/env/target_test.go new file mode 100644 index 000000000..0c15bf917 --- /dev/null +++ b/bundle/env/target_test.go @@ -0,0 +1,43 @@ +package env + +import ( + "context" + "testing" + + "github.com/databricks/cli/internal/testutil" + "github.com/stretchr/testify/assert" +) + +func TestTarget(t *testing.T) { + ctx := context.Background() + + testutil.CleanupEnvironment(t) + + t.Run("first", func(t *testing.T) { + t.Setenv("DATABRICKS_BUNDLE_TARGET", "foo") + target, ok := Target(ctx) + assert.True(t, ok) + assert.Equal(t, "foo", target) + }) + + t.Run("second", func(t *testing.T) { + t.Setenv("DATABRICKS_BUNDLE_ENV", "foo") + target, ok := Target(ctx) + assert.True(t, ok) + assert.Equal(t, "foo", target) + }) + + t.Run("both set", func(t *testing.T) { + t.Setenv("DATABRICKS_BUNDLE_TARGET", "first") + t.Setenv("DATABRICKS_BUNDLE_ENV", "second") + target, ok := Target(ctx) + assert.True(t, ok) + assert.Equal(t, "first", target) + }) + + t.Run("not set", func(t *testing.T) { + target, ok := Target(ctx) + assert.False(t, ok) + assert.Equal(t, "", target) + }) +} diff --git a/bundle/env/temp_dir.go b/bundle/env/temp_dir.go new file mode 100644 index 000000000..b91339079 --- /dev/null +++ b/bundle/env/temp_dir.go @@ -0,0 +1,13 @@ +package env + +import "context" + +// TempDirVariable names the environment variable that holds the temporary directory to use. +const TempDirVariable = "DATABRICKS_BUNDLE_TMP" + +// TempDir returns the temporary directory to use. +func TempDir(ctx context.Context) (string, bool) { + return get(ctx, []string{ + TempDirVariable, + }) +} diff --git a/bundle/env/temp_dir_test.go b/bundle/env/temp_dir_test.go new file mode 100644 index 000000000..7659bac6d --- /dev/null +++ b/bundle/env/temp_dir_test.go @@ -0,0 +1,28 @@ +package env + +import ( + "context" + "testing" + + "github.com/databricks/cli/internal/testutil" + "github.com/stretchr/testify/assert" +) + +func TestTempDir(t *testing.T) { + ctx := context.Background() + + testutil.CleanupEnvironment(t) + + t.Run("set", func(t *testing.T) { + t.Setenv("DATABRICKS_BUNDLE_TMP", "foo") + tempDir, ok := TempDir(ctx) + assert.True(t, ok) + assert.Equal(t, "foo", tempDir) + }) + + t.Run("not set", func(t *testing.T) { + tempDir, ok := TempDir(ctx) + assert.False(t, ok) + assert.Equal(t, "", tempDir) + }) +} diff --git a/bundle/libraries/libraries.go b/bundle/libraries/libraries.go index d26768f95..8e2e504c5 100644 --- a/bundle/libraries/libraries.go +++ b/bundle/libraries/libraries.go @@ -56,11 +56,11 @@ func findAllTasks(b *bundle.Bundle) []*jobs.Task { return result } -func FindAllWheelTasks(b *bundle.Bundle) []*jobs.Task { +func FindAllWheelTasksWithLocalLibraries(b *bundle.Bundle) []*jobs.Task { tasks := findAllTasks(b) wheelTasks := make([]*jobs.Task, 0) for _, task := range tasks { - if task.PythonWheelTask != nil { + if task.PythonWheelTask != nil && IsTaskWithLocalLibraries(task) { wheelTasks = append(wheelTasks, task) } } @@ -68,6 +68,27 @@ func FindAllWheelTasks(b *bundle.Bundle) []*jobs.Task { return wheelTasks } +func IsTaskWithLocalLibraries(task *jobs.Task) bool { + for _, l := range task.Libraries { + if isLocalLibrary(&l) { + return true + } + } + + return false +} + +func IsTaskWithWorkspaceLibraries(task *jobs.Task) bool { + for _, l := range task.Libraries { + path := libPath(&l) + if isWorkspacePath(path) { + return true + } + } + + return false +} + func isMissingRequiredLibraries(task *jobs.Task) bool { if task.Libraries != nil { return false @@ -165,8 +186,8 @@ func isRemoteStorageScheme(path string) bool { return false } - // If the path starts with scheme:// format, it's a correct remote storage scheme - return strings.HasPrefix(path, url.Scheme+"://") + // If the path starts with scheme:/ format, it's a correct remote storage scheme + return strings.HasPrefix(path, url.Scheme+":/") } diff --git a/bundle/libraries/libraries_test.go b/bundle/libraries/libraries_test.go index 050efe749..7ff1609ab 100644 --- a/bundle/libraries/libraries_test.go +++ b/bundle/libraries/libraries_test.go @@ -16,6 +16,7 @@ var testCases map[string]bool = map[string]bool{ "file://path/to/package": true, "C:\\path\\to\\package": true, "dbfs://path/to/package": false, + "dbfs:/path/to/package": false, "s3://path/to/package": false, "abfss://path/to/package": false, } diff --git a/bundle/python/wheel_task_wrappers.go b/bundle/python/wheel_task_wrappers.go index df142e4cf..19df1f7e5 100644 --- a/bundle/python/wheel_task_wrappers.go +++ b/bundle/python/wheel_task_wrappers.go @@ -7,6 +7,7 @@ import ( "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config/mutator" + "github.com/databricks/cli/bundle/libraries" jobs_utils "github.com/databricks/cli/libs/jobs" "github.com/databricks/databricks-sdk-go/service/jobs" ) @@ -70,10 +71,13 @@ func (t *pythonTrampoline) GetTemplate(b *bundle.Bundle, task *jobs.Task) (strin func (t *pythonTrampoline) GetTasks(b *bundle.Bundle) []jobs_utils.TaskWithJobKey { return jobs_utils.GetTasksWithJobKeyBy(b, func(task *jobs.Task) bool { - return task.PythonWheelTask != nil + return task.PythonWheelTask != nil && needsTrampoline(task) }) } +func needsTrampoline(task *jobs.Task) bool { + return libraries.IsTaskWithWorkspaceLibraries(task) +} func (t *pythonTrampoline) GetTemplateData(_ *bundle.Bundle, task *jobs.Task) (map[string]any, error) { params, err := t.generateParameters(task.PythonWheelTask) if err != nil { diff --git a/bundle/python/wheel_task_wrappers_test.go b/bundle/python/wheel_task_wrappers_test.go index a9f57db8e..a7448f234 100644 --- a/bundle/python/wheel_task_wrappers_test.go +++ b/bundle/python/wheel_task_wrappers_test.go @@ -9,6 +9,7 @@ import ( "github.com/databricks/cli/bundle/config" "github.com/databricks/cli/bundle/config/paths" "github.com/databricks/cli/bundle/config/resources" + "github.com/databricks/databricks-sdk-go/service/compute" "github.com/databricks/databricks-sdk-go/service/jobs" "github.com/stretchr/testify/require" ) @@ -82,11 +83,21 @@ func TestTransformFiltersWheelTasksOnly(t *testing.T) { { TaskKey: "key1", PythonWheelTask: &jobs.PythonWheelTask{}, + Libraries: []compute.Library{ + {Whl: "/Workspace/Users/test@test.com/bundle/dist/test.whl"}, + }, }, { TaskKey: "key2", NotebookTask: &jobs.NotebookTask{}, }, + { + TaskKey: "key3", + PythonWheelTask: &jobs.PythonWheelTask{}, + Libraries: []compute.Library{ + {Whl: "dbfs:/FileStore/dist/test.whl"}, + }, + }, }, }, }, diff --git a/bundle/root.go b/bundle/root.go index 46f63e134..7518bf5fc 100644 --- a/bundle/root.go +++ b/bundle/root.go @@ -1,21 +1,21 @@ package bundle import ( + "context" "fmt" "os" "github.com/databricks/cli/bundle/config" + "github.com/databricks/cli/bundle/env" "github.com/databricks/cli/folders" ) -const envBundleRoot = "BUNDLE_ROOT" - -// getRootEnv returns the value of the `BUNDLE_ROOT` environment variable +// getRootEnv returns the value of the bundle root environment variable // if it set and is a directory. If the environment variable is set but // is not a directory, it returns an error. If the environment variable is // not set, it returns an empty string. -func getRootEnv() (string, error) { - path, ok := os.LookupEnv(envBundleRoot) +func getRootEnv(ctx context.Context) (string, error) { + path, ok := env.Root(ctx) if !ok { return "", nil } @@ -24,7 +24,7 @@ func getRootEnv() (string, error) { err = fmt.Errorf("not a directory") } if err != nil { - return "", fmt.Errorf(`invalid bundle root %s="%s": %w`, envBundleRoot, path, err) + return "", fmt.Errorf(`invalid bundle root %s="%s": %w`, env.RootVariable, path, err) } return path, nil } @@ -48,8 +48,8 @@ func getRootWithTraversal() (string, error) { } // mustGetRoot returns a bundle root or an error if one cannot be found. -func mustGetRoot() (string, error) { - path, err := getRootEnv() +func mustGetRoot(ctx context.Context) (string, error) { + path, err := getRootEnv(ctx) if path != "" || err != nil { return path, err } @@ -57,9 +57,9 @@ func mustGetRoot() (string, error) { } // tryGetRoot returns a bundle root or an empty string if one cannot be found. -func tryGetRoot() (string, error) { +func tryGetRoot(ctx context.Context) (string, error) { // Note: an invalid value in the environment variable is still an error. - path, err := getRootEnv() + path, err := getRootEnv(ctx) if path != "" || err != nil { return path, err } diff --git a/bundle/root_test.go b/bundle/root_test.go index 0c4c46aaf..88113546c 100644 --- a/bundle/root_test.go +++ b/bundle/root_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/databricks/cli/bundle/config" + "github.com/databricks/cli/bundle/env" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -32,49 +33,55 @@ func chdir(t *testing.T, dir string) string { } func TestRootFromEnv(t *testing.T) { + ctx := context.Background() dir := t.TempDir() - t.Setenv(envBundleRoot, dir) + t.Setenv(env.RootVariable, dir) // It should pull the root from the environment variable. - root, err := mustGetRoot() + root, err := mustGetRoot(ctx) require.NoError(t, err) require.Equal(t, root, dir) } func TestRootFromEnvDoesntExist(t *testing.T) { + ctx := context.Background() dir := t.TempDir() - t.Setenv(envBundleRoot, filepath.Join(dir, "doesntexist")) + t.Setenv(env.RootVariable, filepath.Join(dir, "doesntexist")) // It should pull the root from the environment variable. - _, err := mustGetRoot() + _, err := mustGetRoot(ctx) require.Errorf(t, err, "invalid bundle root") } func TestRootFromEnvIsFile(t *testing.T) { + ctx := context.Background() dir := t.TempDir() f, err := os.Create(filepath.Join(dir, "invalid")) require.NoError(t, err) f.Close() - t.Setenv(envBundleRoot, f.Name()) + t.Setenv(env.RootVariable, f.Name()) // It should pull the root from the environment variable. - _, err = mustGetRoot() + _, err = mustGetRoot(ctx) require.Errorf(t, err, "invalid bundle root") } func TestRootIfEnvIsEmpty(t *testing.T) { + ctx := context.Background() dir := "" - t.Setenv(envBundleRoot, dir) + t.Setenv(env.RootVariable, dir) // It should pull the root from the environment variable. - _, err := mustGetRoot() + _, err := mustGetRoot(ctx) require.Errorf(t, err, "invalid bundle root") } func TestRootLookup(t *testing.T) { + ctx := context.Background() + // Have to set then unset to allow the testing package to revert it to its original value. - t.Setenv(envBundleRoot, "") - os.Unsetenv(envBundleRoot) + t.Setenv(env.RootVariable, "") + os.Unsetenv(env.RootVariable) chdir(t, t.TempDir()) @@ -89,27 +96,30 @@ func TestRootLookup(t *testing.T) { // It should find the project root from $PWD. wd := chdir(t, "./a/b/c") - root, err := mustGetRoot() + root, err := mustGetRoot(ctx) require.NoError(t, err) require.Equal(t, wd, root) } func TestRootLookupError(t *testing.T) { + ctx := context.Background() + // Have to set then unset to allow the testing package to revert it to its original value. - t.Setenv(envBundleRoot, "") - os.Unsetenv(envBundleRoot) + t.Setenv(env.RootVariable, "") + os.Unsetenv(env.RootVariable) // It can't find a project root from a temporary directory. _ = chdir(t, t.TempDir()) - _, err := mustGetRoot() + _, err := mustGetRoot(ctx) require.ErrorContains(t, err, "unable to locate bundle root") } func TestLoadYamlWhenIncludesEnvPresent(t *testing.T) { + ctx := context.Background() chdir(t, filepath.Join(".", "tests", "basic")) - t.Setenv(ExtraIncludePathsKey, "test") + t.Setenv(env.IncludesVariable, "test") - bundle, err := MustLoad(context.Background()) + bundle, err := MustLoad(ctx) assert.NoError(t, err) assert.Equal(t, "basic", bundle.Config.Bundle.Name) @@ -119,30 +129,33 @@ func TestLoadYamlWhenIncludesEnvPresent(t *testing.T) { } func TestLoadDefautlBundleWhenNoYamlAndRootAndIncludesEnvPresent(t *testing.T) { + ctx := context.Background() dir := t.TempDir() chdir(t, dir) - t.Setenv(envBundleRoot, dir) - t.Setenv(ExtraIncludePathsKey, "test") + t.Setenv(env.RootVariable, dir) + t.Setenv(env.IncludesVariable, "test") - bundle, err := MustLoad(context.Background()) + bundle, err := MustLoad(ctx) assert.NoError(t, err) assert.Equal(t, dir, bundle.Config.Path) } func TestErrorIfNoYamlNoRootEnvAndIncludesEnvPresent(t *testing.T) { + ctx := context.Background() dir := t.TempDir() chdir(t, dir) - t.Setenv(ExtraIncludePathsKey, "test") + t.Setenv(env.IncludesVariable, "test") - _, err := MustLoad(context.Background()) + _, err := MustLoad(ctx) assert.Error(t, err) } func TestErrorIfNoYamlNoIncludesEnvAndRootEnvPresent(t *testing.T) { + ctx := context.Background() dir := t.TempDir() chdir(t, dir) - t.Setenv(envBundleRoot, dir) + t.Setenv(env.RootVariable, dir) - _, err := MustLoad(context.Background()) + _, err := MustLoad(ctx) assert.Error(t, err) } diff --git a/bundle/run/job.go b/bundle/run/job.go index f152a17d0..319cd1464 100644 --- a/bundle/run/job.go +++ b/bundle/run/job.go @@ -95,6 +95,13 @@ type jobRunner struct { job *resources.Job } +func (r *jobRunner) Name() string { + if r.job == nil || r.job.JobSettings == nil { + return "" + } + return r.job.JobSettings.Name +} + func isFailed(task jobs.RunTask) bool { return task.State.LifeCycleState == jobs.RunLifeCycleStateInternalError || (task.State.LifeCycleState == jobs.RunLifeCycleStateTerminated && diff --git a/bundle/run/keys.go b/bundle/run/keys.go index c8b7a2b5b..76ec50ac8 100644 --- a/bundle/run/keys.go +++ b/bundle/run/keys.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/databricks/cli/bundle" + "golang.org/x/exp/maps" ) // RunnerLookup maps identifiers to a list of workloads that match that identifier. @@ -32,18 +33,20 @@ func ResourceKeys(b *bundle.Bundle) (keyOnly RunnerLookup, keyWithType RunnerLoo return } -// ResourceCompletions returns a list of keys that unambiguously reference resources in the bundle. -func ResourceCompletions(b *bundle.Bundle) []string { - seen := make(map[string]bool) - comps := []string{} +// ResourceCompletionMap returns a map of resource keys to their respective names. +func ResourceCompletionMap(b *bundle.Bundle) map[string]string { + out := make(map[string]string) keyOnly, keyWithType := ResourceKeys(b) + // Keep track of resources we have seen by their fully qualified key. + seen := make(map[string]bool) + // First add resources that can be identified by key alone. for k, v := range keyOnly { // Invariant: len(v) >= 1. See [ResourceKeys]. if len(v) == 1 { seen[v[0].Key()] = true - comps = append(comps, k) + out[k] = v[0].Name() } } @@ -54,8 +57,13 @@ func ResourceCompletions(b *bundle.Bundle) []string { if ok { continue } - comps = append(comps, k) + out[k] = v[0].Name() } - return comps + return out +} + +// ResourceCompletions returns a list of keys that unambiguously reference resources in the bundle. +func ResourceCompletions(b *bundle.Bundle) []string { + return maps.Keys(ResourceCompletionMap(b)) } diff --git a/bundle/run/output/job.go b/bundle/run/output/job.go index 4bea4c7ad..6199ac2f7 100644 --- a/bundle/run/output/job.go +++ b/bundle/run/output/job.go @@ -60,7 +60,7 @@ func GetJobOutput(ctx context.Context, w *databricks.WorkspaceClient, runId int6 return nil, err } result := &JobOutput{ - TaskOutputs: make([]TaskOutput, len(jobRun.Tasks)), + TaskOutputs: make([]TaskOutput, 0), } for _, task := range jobRun.Tasks { jobRunOutput, err := w.Jobs.GetRunOutput(ctx, jobs.GetRunOutputRequest{ @@ -69,7 +69,11 @@ func GetJobOutput(ctx context.Context, w *databricks.WorkspaceClient, runId int6 if err != nil { return nil, err } - task := TaskOutput{TaskKey: task.TaskKey, Output: toRunOutput(jobRunOutput), EndTime: task.EndTime} + out := toRunOutput(jobRunOutput) + if out == nil { + continue + } + task := TaskOutput{TaskKey: task.TaskKey, Output: out, EndTime: task.EndTime} result.TaskOutputs = append(result.TaskOutputs, task) } return result, nil diff --git a/bundle/run/pipeline.go b/bundle/run/pipeline.go index 7b82c3eae..216712d30 100644 --- a/bundle/run/pipeline.go +++ b/bundle/run/pipeline.go @@ -136,6 +136,13 @@ type pipelineRunner struct { pipeline *resources.Pipeline } +func (r *pipelineRunner) Name() string { + if r.pipeline == nil || r.pipeline.PipelineSpec == nil { + return "" + } + return r.pipeline.PipelineSpec.Name +} + func (r *pipelineRunner) Run(ctx context.Context, opts *Options) (output.RunOutput, error) { var pipelineID = r.pipeline.ID diff --git a/bundle/run/runner.go b/bundle/run/runner.go index 227e12d97..7d3c2c297 100644 --- a/bundle/run/runner.go +++ b/bundle/run/runner.go @@ -21,6 +21,9 @@ type Runner interface { // This is used for showing the user hints w.r.t. disambiguation. Key() string + // Name returns the resource's name, if defined. + Name() string + // Run the underlying worklow. Run(ctx context.Context, opts *Options) (output.RunOutput, error) } diff --git a/bundle/schema/docs/bundle_descriptions.json b/bundle/schema/docs/bundle_descriptions.json index 84f0492fb..98f3cf8d0 100644 --- a/bundle/schema/docs/bundle_descriptions.json +++ b/bundle/schema/docs/bundle_descriptions.json @@ -1441,6 +1441,86 @@ } } }, + "model_serving_endpoints": { + "description": "List of Model Serving Endpoints", + "additionalproperties": { + "description": "", + "properties": { + "name": { + "description": "The name of the model serving endpoint. This field is required and must be unique across a workspace. An endpoint name can consist of alphanumeric characters, dashes, and underscores. NOTE: Changing this name will delete the existing endpoint and create a new endpoint with the update name." + }, + "permissions": { + "description": "", + "items": { + "description": "", + "properties": { + "group_name": { + "description": "" + }, + "level": { + "description": "" + }, + "service_principal_name": { + "description": "" + }, + "user_name": { + "description": "" + } + } + } + }, + "config": { + "description": "The model serving endpoint configuration.", + "properties": { + "properties": { + "served_models": { + "description": "Each block represents a served model for the endpoint to serve. A model serving endpoint can have up to 10 served models.", + "items": { + "description": "", + "properties" : { + "name": { + "description": "The name of a served model. It must be unique across an endpoint. If not specified, this field will default to modelname-modelversion. A served model name can consist of alphanumeric characters, dashes, and underscores." + }, + "model_name": { + "description": "The name of the model in Databricks Model Registry to be served." + }, + "model_version": { + "description": "The version of the model in Databricks Model Registry to be served." + }, + "workload_size": { + "description": "The workload size of the served model. The workload size corresponds to a range of provisioned concurrency that the compute will autoscale between. A single unit of provisioned concurrency can process one request at a time. Valid workload sizes are \"Small\" (4 - 4 provisioned concurrency), \"Medium\" (8 - 16 provisioned concurrency), and \"Large\" (16 - 64 provisioned concurrency)." + }, + "scale_to_zero_enabled": { + "description": "Whether the compute resources for the served model should scale down to zero. If scale-to-zero is enabled, the lower bound of the provisioned concurrency for each workload size will be 0." + } + } + } + }, + "traffic_config": { + "description": "A single block represents the traffic split configuration amongst the served models.", + "properties": { + "routes": { + "description": "Each block represents a route that defines traffic to each served model. Each served_models block needs to have a corresponding routes block.", + "items": { + "description": "", + "properties": { + "served_model_name": { + "description": "The name of the served model this route configures traffic for. This needs to match the name of a served_models block." + }, + "traffic_percentage": { + "description": "The percentage of endpoint traffic to send to this route. It must be an integer between 0 and 100 inclusive." + } + } + } + } + } + } + } + } + } + } + } + }, "pipelines": { "description": "List of DLT pipelines", "additionalproperties": { diff --git a/bundle/schema/openapi.go b/bundle/schema/openapi.go index b0d676576..1a8b76ed9 100644 --- a/bundle/schema/openapi.go +++ b/bundle/schema/openapi.go @@ -210,6 +210,19 @@ func (reader *OpenapiReader) modelsDocs() (*Docs, error) { return modelsDocs, nil } +func (reader *OpenapiReader) modelServingEndpointsDocs() (*Docs, error) { + modelServingEndpointsSpecSchema, err := reader.readResolvedSchema(SchemaPathPrefix + "serving.CreateServingEndpoint") + if err != nil { + return nil, err + } + modelServingEndpointsDocs := schemaToDocs(modelServingEndpointsSpecSchema) + modelServingEndpointsAllDocs := &Docs{ + Description: "List of Model Serving Endpoints", + AdditionalProperties: modelServingEndpointsDocs, + } + return modelServingEndpointsAllDocs, nil +} + func (reader *OpenapiReader) ResourcesDocs() (*Docs, error) { jobsDocs, err := reader.jobsDocs() if err != nil { @@ -227,14 +240,19 @@ func (reader *OpenapiReader) ResourcesDocs() (*Docs, error) { if err != nil { return nil, err } + modelServingEndpointsDocs, err := reader.modelServingEndpointsDocs() + if err != nil { + return nil, err + } return &Docs{ Description: "Collection of Databricks resources to deploy.", Properties: map[string]*Docs{ - "jobs": jobsDocs, - "pipelines": pipelinesDocs, - "experiments": experimentsDocs, - "models": modelsDocs, + "jobs": jobsDocs, + "pipelines": pipelinesDocs, + "experiments": experimentsDocs, + "models": modelsDocs, + "model_serving_endpoints": modelServingEndpointsDocs, }, }, nil } diff --git a/bundle/tests/bundle/python_wheel_dbfs_lib/bundle.yml b/bundle/tests/bundle/python_wheel_dbfs_lib/bundle.yml index 54577d658..07f4957bb 100644 --- a/bundle/tests/bundle/python_wheel_dbfs_lib/bundle.yml +++ b/bundle/tests/bundle/python_wheel_dbfs_lib/bundle.yml @@ -12,4 +12,4 @@ resources: package_name: "my_test_code" entry_point: "run" libraries: - - whl: dbfs://path/to/dist/mywheel.whl + - whl: dbfs:/path/to/dist/mywheel.whl diff --git a/bundle/tests/model_serving_endpoint/databricks.yml b/bundle/tests/model_serving_endpoint/databricks.yml new file mode 100644 index 000000000..e4fb54a1f --- /dev/null +++ b/bundle/tests/model_serving_endpoint/databricks.yml @@ -0,0 +1,38 @@ +resources: + model_serving_endpoints: + my_model_serving_endpoint: + name: "my-endpoint" + config: + served_models: + - model_name: "model-name" + model_version: "1" + workload_size: "Small" + scale_to_zero_enabled: true + traffic_config: + routes: + - served_model_name: "model-name-1" + traffic_percentage: 100 + permissions: + - level: CAN_QUERY + group_name: users + +targets: + development: + mode: development + resources: + model_serving_endpoints: + my_model_serving_endpoint: + name: "my-dev-endpoint" + + staging: + resources: + model_serving_endpoints: + my_model_serving_endpoint: + name: "my-staging-endpoint" + + production: + mode: production + resources: + model_serving_endpoints: + my_model_serving_endpoint: + name: "my-prod-endpoint" diff --git a/bundle/tests/model_serving_endpoint_test.go b/bundle/tests/model_serving_endpoint_test.go new file mode 100644 index 000000000..bfa1a31b4 --- /dev/null +++ b/bundle/tests/model_serving_endpoint_test.go @@ -0,0 +1,48 @@ +package config_tests + +import ( + "path/filepath" + "testing" + + "github.com/databricks/cli/bundle/config" + "github.com/databricks/cli/bundle/config/resources" + "github.com/stretchr/testify/assert" +) + +func assertExpected(t *testing.T, p *resources.ModelServingEndpoint) { + assert.Equal(t, "model_serving_endpoint/databricks.yml", filepath.ToSlash(p.ConfigFilePath)) + assert.Equal(t, "model-name", p.Config.ServedModels[0].ModelName) + assert.Equal(t, "1", p.Config.ServedModels[0].ModelVersion) + assert.Equal(t, "model-name-1", p.Config.TrafficConfig.Routes[0].ServedModelName) + assert.Equal(t, 100, p.Config.TrafficConfig.Routes[0].TrafficPercentage) + assert.Equal(t, "users", p.Permissions[0].GroupName) + assert.Equal(t, "CAN_QUERY", p.Permissions[0].Level) +} + +func TestModelServingEndpointDevelopment(t *testing.T) { + b := loadTarget(t, "./model_serving_endpoint", "development") + assert.Len(t, b.Config.Resources.ModelServingEndpoints, 1) + assert.Equal(t, b.Config.Bundle.Mode, config.Development) + + p := b.Config.Resources.ModelServingEndpoints["my_model_serving_endpoint"] + assert.Equal(t, "my-dev-endpoint", p.Name) + assertExpected(t, p) +} + +func TestModelServingEndpointStaging(t *testing.T) { + b := loadTarget(t, "./model_serving_endpoint", "staging") + assert.Len(t, b.Config.Resources.ModelServingEndpoints, 1) + + p := b.Config.Resources.ModelServingEndpoints["my_model_serving_endpoint"] + assert.Equal(t, "my-staging-endpoint", p.Name) + assertExpected(t, p) +} + +func TestModelServingEndpointProduction(t *testing.T) { + b := loadTarget(t, "./model_serving_endpoint", "production") + assert.Len(t, b.Config.Resources.ModelServingEndpoints, 1) + + p := b.Config.Resources.ModelServingEndpoints["my_model_serving_endpoint"] + assert.Equal(t, "my-prod-endpoint", p.Name) + assertExpected(t, p) +} diff --git a/bundle/tests/suggest_target_test.go b/bundle/tests/suggest_target_test.go new file mode 100644 index 000000000..924d6a4e1 --- /dev/null +++ b/bundle/tests/suggest_target_test.go @@ -0,0 +1,17 @@ +package config_tests + +import ( + "path/filepath" + "testing" + + "github.com/databricks/cli/internal" + "github.com/stretchr/testify/require" +) + +func TestSuggestTargetIfWrongPassed(t *testing.T) { + t.Setenv("BUNDLE_ROOT", filepath.Join("target_overrides", "workspace")) + _, _, err := internal.RequireErrorRun(t, "bundle", "validate", "-e", "incorrect") + require.ErrorContains(t, err, "Available targets:") + require.ErrorContains(t, err, "development") + require.ErrorContains(t, err, "staging") +} diff --git a/cmd/bundle/bundle.go b/cmd/bundle/bundle.go index c933ec9c3..d8382d172 100644 --- a/cmd/bundle/bundle.go +++ b/cmd/bundle/bundle.go @@ -7,7 +7,7 @@ import ( func New() *cobra.Command { cmd := &cobra.Command{ Use: "bundle", - Short: "Databricks Asset Bundles", + Short: "Databricks Asset Bundles\n\nOnline documentation: https://docs.databricks.com/en/dev-tools/bundles", } initVariableFlag(cmd) diff --git a/cmd/bundle/init.go b/cmd/bundle/init.go index 9a11eb257..3038cb7a2 100644 --- a/cmd/bundle/init.go +++ b/cmd/bundle/init.go @@ -74,22 +74,21 @@ func newInitCommand() *cobra.Command { return template.Materialize(ctx, configFile, templatePath, outputDir) } - // Download the template in a temporary directory - tmpDir := os.TempDir() - templateURL := templatePath - repoDir := filepath.Join(tmpDir, repoName(templateURL)) - err := os.MkdirAll(repoDir, 0755) + // Create a temporary directory with the name of the repository. The '*' + // character is replaced by a random string in the generated temporary directory. + repoDir, err := os.MkdirTemp("", repoName(templatePath)+"-*") if err != nil { return err } // TODO: Add automated test that the downloaded git repo is cleaned up. - err = git.Clone(ctx, templateURL, "", repoDir) + // Clone the repository in the temporary directory + err = git.Clone(ctx, templatePath, "", repoDir) if err != nil { return err } - defer os.RemoveAll(templateDir) + // Clean up downloaded repository once the template is materialized. + defer os.RemoveAll(repoDir) return template.Materialize(ctx, configFile, filepath.Join(repoDir, templateDir), outputDir) } - return cmd } diff --git a/cmd/bundle/run.go b/cmd/bundle/run.go index 28b9ae7cd..b5a60ee15 100644 --- a/cmd/bundle/run.go +++ b/cmd/bundle/run.go @@ -9,6 +9,7 @@ import ( "github.com/databricks/cli/bundle/phases" "github.com/databricks/cli/bundle/run" "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/flags" "github.com/spf13/cobra" ) @@ -16,9 +17,9 @@ import ( func newRunCommand() *cobra.Command { cmd := &cobra.Command{ Use: "run [flags] KEY", - Short: "Run a workload (e.g. a job or a pipeline)", + Short: "Run a resource (e.g. a job or a pipeline)", - Args: cobra.ExactArgs(1), + Args: cobra.MaximumNArgs(1), PreRunE: ConfigureBundleWithVariables, } @@ -29,9 +30,10 @@ func newRunCommand() *cobra.Command { cmd.Flags().BoolVar(&noWait, "no-wait", false, "Don't wait for the run to complete.") cmd.RunE = func(cmd *cobra.Command, args []string) error { - b := bundle.Get(cmd.Context()) + ctx := cmd.Context() + b := bundle.Get(ctx) - err := bundle.Apply(cmd.Context(), b, bundle.Seq( + err := bundle.Apply(ctx, b, bundle.Seq( phases.Initialize(), terraform.Interpolate(), terraform.Write(), @@ -42,13 +44,31 @@ func newRunCommand() *cobra.Command { return err } + // If no arguments are specified, prompt the user to select something to run. + if len(args) == 0 && cmdio.IsInteractive(ctx) { + // Invert completions from KEY -> NAME, to NAME -> KEY. + inv := make(map[string]string) + for k, v := range run.ResourceCompletionMap(b) { + inv[v] = k + } + id, err := cmdio.Select(ctx, inv, "Resource to run") + if err != nil { + return err + } + args = append(args, id) + } + + if len(args) != 1 { + return fmt.Errorf("expected a KEY of the resource to run") + } + runner, err := run.Find(b, args[0]) if err != nil { return err } runOptions.NoWait = noWait - output, err := runner.Run(cmd.Context(), &runOptions) + output, err := runner.Run(ctx, &runOptions) if err != nil { return err } diff --git a/cmd/bundle/sync.go b/cmd/bundle/sync.go index be45626a3..6d6a6f5a3 100644 --- a/cmd/bundle/sync.go +++ b/cmd/bundle/sync.go @@ -18,12 +18,12 @@ type syncFlags struct { } func (f *syncFlags) syncOptionsFromBundle(cmd *cobra.Command, b *bundle.Bundle) (*sync.SyncOptions, error) { - cacheDir, err := b.CacheDir() + cacheDir, err := b.CacheDir(cmd.Context()) if err != nil { return nil, fmt.Errorf("cannot get bundle cache directory: %w", err) } - includes, err := b.GetSyncIncludePatterns() + includes, err := b.GetSyncIncludePatterns(cmd.Context()) if err != nil { return nil, fmt.Errorf("cannot get list of sync includes: %w", err) } diff --git a/cmd/cmd.go b/cmd/cmd.go index 032fde5cd..6dd0f6e21 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "strings" "github.com/databricks/cli/cmd/account" @@ -21,8 +22,8 @@ const ( permissionsGroup = "permissions" ) -func New() *cobra.Command { - cli := root.New() +func New(ctx context.Context) *cobra.Command { + cli := root.New(ctx) // Add account subcommand. cli.AddCommand(account.New()) diff --git a/cmd/configure/configure_test.go b/cmd/configure/configure_test.go index e1ebe916b..cf0505edd 100644 --- a/cmd/configure/configure_test.go +++ b/cmd/configure/configure_test.go @@ -54,7 +54,7 @@ func TestDefaultConfigureNoInteractive(t *testing.T) { }) os.Stdin = inp - cmd := cmd.New() + cmd := cmd.New(ctx) cmd.SetArgs([]string{"configure", "--token", "--host", "https://host"}) err := cmd.ExecuteContext(ctx) @@ -87,7 +87,7 @@ func TestConfigFileFromEnvNoInteractive(t *testing.T) { t.Cleanup(func() { os.Stdin = oldStdin }) os.Stdin = inp - cmd := cmd.New() + cmd := cmd.New(ctx) cmd.SetArgs([]string{"configure", "--token", "--host", "https://host"}) err := cmd.ExecuteContext(ctx) @@ -116,7 +116,7 @@ func TestCustomProfileConfigureNoInteractive(t *testing.T) { t.Cleanup(func() { os.Stdin = oldStdin }) os.Stdin = inp - cmd := cmd.New() + cmd := cmd.New(ctx) cmd.SetArgs([]string{"configure", "--token", "--host", "https://host", "--profile", "CUSTOM"}) err := cmd.ExecuteContext(ctx) diff --git a/cmd/root/auth.go b/cmd/root/auth.go index d4c9a31b9..de5648c65 100644 --- a/cmd/root/auth.go +++ b/cmd/root/auth.go @@ -25,13 +25,57 @@ func initProfileFlag(cmd *cobra.Command) { cmd.RegisterFlagCompletionFunc("profile", databrickscfg.ProfileCompletion) } +func profileFlagValue(cmd *cobra.Command) (string, bool) { + profileFlag := cmd.Flag("profile") + if profileFlag == nil { + return "", false + } + value := profileFlag.Value.String() + return value, value != "" +} + +// Helper function to create an account client or prompt once if the given configuration is not valid. +func accountClientOrPrompt(ctx context.Context, cfg *config.Config, allowPrompt bool) (*databricks.AccountClient, error) { + a, err := databricks.NewAccountClient((*databricks.Config)(cfg)) + if err == nil { + err = a.Config.Authenticate(emptyHttpRequest(ctx)) + } + + prompt := false + if allowPrompt && err != nil && cmdio.IsInteractive(ctx) { + // Prompt to select a profile if the current configuration is not an account client. + prompt = prompt || errors.Is(err, databricks.ErrNotAccountClient) + // Prompt to select a profile if the current configuration doesn't resolve to a credential provider. + prompt = prompt || errors.Is(err, config.ErrCannotConfigureAuth) + } + + if !prompt { + // If we are not prompting, we can return early. + return a, err + } + + // Try picking a profile dynamically if the current configuration is not valid. + profile, err := askForAccountProfile(ctx) + if err != nil { + return nil, err + } + a, err = databricks.NewAccountClient(&databricks.Config{Profile: profile}) + if err == nil { + err = a.Config.Authenticate(emptyHttpRequest(ctx)) + if err != nil { + return nil, err + } + } + return a, nil +} + func MustAccountClient(cmd *cobra.Command, args []string) error { cfg := &config.Config{} - // command-line flag can specify the profile in use - profileFlag := cmd.Flag("profile") - if profileFlag != nil { - cfg.Profile = profileFlag.Value.String() + // The command-line profile flag takes precedence over DATABRICKS_CONFIG_PROFILE. + profile, hasProfileFlag := profileFlagValue(cmd) + if hasProfileFlag { + cfg.Profile = profile } if cfg.Profile == "" { @@ -48,16 +92,8 @@ func MustAccountClient(cmd *cobra.Command, args []string) error { } } -TRY_AUTH: // or try picking a config profile dynamically - a, err := databricks.NewAccountClient((*databricks.Config)(cfg)) - if cmdio.IsInteractive(cmd.Context()) && errors.Is(err, databricks.ErrNotAccountClient) { - profile, err := askForAccountProfile() - if err != nil { - return err - } - cfg = &config.Config{Profile: profile} - goto TRY_AUTH - } + allowPrompt := !hasProfileFlag + a, err := accountClientOrPrompt(cmd.Context(), cfg, allowPrompt) if err != nil { return err } @@ -66,13 +102,48 @@ TRY_AUTH: // or try picking a config profile dynamically return nil } +// Helper function to create a workspace client or prompt once if the given configuration is not valid. +func workspaceClientOrPrompt(ctx context.Context, cfg *config.Config, allowPrompt bool) (*databricks.WorkspaceClient, error) { + w, err := databricks.NewWorkspaceClient((*databricks.Config)(cfg)) + if err == nil { + err = w.Config.Authenticate(emptyHttpRequest(ctx)) + } + + prompt := false + if allowPrompt && err != nil && cmdio.IsInteractive(ctx) { + // Prompt to select a profile if the current configuration is not a workspace client. + prompt = prompt || errors.Is(err, databricks.ErrNotWorkspaceClient) + // Prompt to select a profile if the current configuration doesn't resolve to a credential provider. + prompt = prompt || errors.Is(err, config.ErrCannotConfigureAuth) + } + + if !prompt { + // If we are not prompting, we can return early. + return w, err + } + + // Try picking a profile dynamically if the current configuration is not valid. + profile, err := askForWorkspaceProfile(ctx) + if err != nil { + return nil, err + } + w, err = databricks.NewWorkspaceClient(&databricks.Config{Profile: profile}) + if err == nil { + err = w.Config.Authenticate(emptyHttpRequest(ctx)) + if err != nil { + return nil, err + } + } + return w, nil +} + func MustWorkspaceClient(cmd *cobra.Command, args []string) error { cfg := &config.Config{} - // command-line flag takes precedence over environment variable - profileFlag := cmd.Flag("profile") - if profileFlag != nil { - cfg.Profile = profileFlag.Value.String() + // The command-line profile flag takes precedence over DATABRICKS_CONFIG_PROFILE. + profile, hasProfileFlag := profileFlagValue(cmd) + if hasProfileFlag { + cfg.Profile = profile } // try configuring a bundle @@ -87,24 +158,13 @@ func MustWorkspaceClient(cmd *cobra.Command, args []string) error { cfg = currentBundle.WorkspaceClient().Config } -TRY_AUTH: // or try picking a config profile dynamically + allowPrompt := !hasProfileFlag + w, err := workspaceClientOrPrompt(cmd.Context(), cfg, allowPrompt) + if err != nil { + return err + } + ctx := cmd.Context() - w, err := databricks.NewWorkspaceClient((*databricks.Config)(cfg)) - if err != nil { - return err - } - err = w.Config.Authenticate(emptyHttpRequest(ctx)) - if cmdio.IsInteractive(ctx) && errors.Is(err, config.ErrCannotConfigureAuth) { - profile, err := askForWorkspaceProfile() - if err != nil { - return err - } - cfg = &config.Config{Profile: profile} - goto TRY_AUTH - } - if err != nil { - return err - } ctx = context.WithValue(ctx, &workspaceClient, w) cmd.SetContext(ctx) return nil @@ -121,7 +181,7 @@ func transformLoadError(path string, err error) error { return err } -func askForWorkspaceProfile() (string, error) { +func askForWorkspaceProfile(ctx context.Context) (string, error) { path, err := databrickscfg.GetPath() if err != nil { return "", fmt.Errorf("cannot determine Databricks config file path: %w", err) @@ -136,7 +196,7 @@ func askForWorkspaceProfile() (string, error) { case 1: return profiles[0].Name, nil } - i, _, err := (&promptui.Select{ + i, _, err := cmdio.RunSelect(ctx, &promptui.Select{ Label: fmt.Sprintf("Workspace profiles defined in %s", file), Items: profiles, Searcher: profiles.SearchCaseInsensitive, @@ -147,16 +207,14 @@ func askForWorkspaceProfile() (string, error) { Inactive: `{{.Name}}`, Selected: `{{ "Using workspace profile" | faint }}: {{ .Name | bold }}`, }, - Stdin: os.Stdin, - Stdout: os.Stderr, - }).Run() + }) if err != nil { return "", err } return profiles[i].Name, nil } -func askForAccountProfile() (string, error) { +func askForAccountProfile(ctx context.Context) (string, error) { path, err := databrickscfg.GetPath() if err != nil { return "", fmt.Errorf("cannot determine Databricks config file path: %w", err) @@ -171,7 +229,7 @@ func askForAccountProfile() (string, error) { case 1: return profiles[0].Name, nil } - i, _, err := (&promptui.Select{ + i, _, err := cmdio.RunSelect(ctx, &promptui.Select{ Label: fmt.Sprintf("Account profiles defined in %s", file), Items: profiles, Searcher: profiles.SearchCaseInsensitive, @@ -182,9 +240,7 @@ func askForAccountProfile() (string, error) { Inactive: `{{.Name}}`, Selected: `{{ "Using account profile" | faint }}: {{ .Name | bold }}`, }, - Stdin: os.Stdin, - Stdout: os.Stderr, - }).Run() + }) if err != nil { return "", err } diff --git a/cmd/root/auth_test.go b/cmd/root/auth_test.go index 75d255b58..30fa9a086 100644 --- a/cmd/root/auth_test.go +++ b/cmd/root/auth_test.go @@ -2,9 +2,16 @@ package root import ( "context" + "os" + "path/filepath" "testing" + "time" + "github.com/databricks/cli/internal/testutil" + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/databricks-sdk-go/config" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestEmptyHttpRequest(t *testing.T) { @@ -12,3 +19,165 @@ func TestEmptyHttpRequest(t *testing.T) { req := emptyHttpRequest(ctx) assert.Equal(t, req.Context(), ctx) } + +type promptFn func(ctx context.Context, cfg *config.Config, retry bool) (any, error) + +var accountPromptFn = func(ctx context.Context, cfg *config.Config, retry bool) (any, error) { + return accountClientOrPrompt(ctx, cfg, retry) +} + +var workspacePromptFn = func(ctx context.Context, cfg *config.Config, retry bool) (any, error) { + return workspaceClientOrPrompt(ctx, cfg, retry) +} + +func expectPrompts(t *testing.T, fn promptFn, config *config.Config) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + // Channel to pass errors from the prompting function back to the test. + errch := make(chan error, 1) + + ctx, io := cmdio.SetupTest(ctx) + go func() { + defer close(errch) + defer cancel() + _, err := fn(ctx, config, true) + errch <- err + }() + + // Expect a prompt + line, _, err := io.Stderr.ReadLine() + if assert.NoError(t, err, "Expected to read a line from stderr") { + assert.Contains(t, string(line), "Search:") + } else { + // If there was an error reading from stderr, the prompting function must have terminated early. + assert.NoError(t, <-errch) + } +} + +func expectReturns(t *testing.T, fn promptFn, config *config.Config) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + ctx, _ = cmdio.SetupTest(ctx) + client, err := fn(ctx, config, true) + require.NoError(t, err) + require.NotNil(t, client) +} + +func TestAccountClientOrPrompt(t *testing.T) { + testutil.CleanupEnvironment(t) + + dir := t.TempDir() + configFile := filepath.Join(dir, ".databrickscfg") + err := os.WriteFile( + configFile, + []byte(` + [account-1111] + host = https://accounts.azuredatabricks.net/ + account_id = 1111 + token = foobar + + [account-1112] + host = https://accounts.azuredatabricks.net/ + account_id = 1112 + token = foobar + `), + 0755) + require.NoError(t, err) + t.Setenv("DATABRICKS_CONFIG_FILE", configFile) + t.Setenv("PATH", "/nothing") + + t.Run("Prompt if nothing is specified", func(t *testing.T) { + expectPrompts(t, accountPromptFn, &config.Config{}) + }) + + t.Run("Prompt if a workspace host is specified", func(t *testing.T) { + expectPrompts(t, accountPromptFn, &config.Config{ + Host: "https://adb-1234567.89.azuredatabricks.net/", + AccountID: "1234", + Token: "foobar", + }) + }) + + t.Run("Prompt if account ID is not specified", func(t *testing.T) { + expectPrompts(t, accountPromptFn, &config.Config{ + Host: "https://accounts.azuredatabricks.net/", + Token: "foobar", + }) + }) + + t.Run("Prompt if no credential provider can be configured", func(t *testing.T) { + expectPrompts(t, accountPromptFn, &config.Config{ + Host: "https://accounts.azuredatabricks.net/", + AccountID: "1234", + }) + }) + + t.Run("Returns if configuration is valid", func(t *testing.T) { + expectReturns(t, accountPromptFn, &config.Config{ + Host: "https://accounts.azuredatabricks.net/", + AccountID: "1234", + Token: "foobar", + }) + }) + + t.Run("Returns if a valid profile is specified", func(t *testing.T) { + expectReturns(t, accountPromptFn, &config.Config{ + Profile: "account-1111", + }) + }) +} + +func TestWorkspaceClientOrPrompt(t *testing.T) { + testutil.CleanupEnvironment(t) + + dir := t.TempDir() + configFile := filepath.Join(dir, ".databrickscfg") + err := os.WriteFile( + configFile, + []byte(` + [workspace-1111] + host = https://adb-1111.11.azuredatabricks.net/ + token = foobar + + [workspace-1112] + host = https://adb-1112.12.azuredatabricks.net/ + token = foobar + `), + 0755) + require.NoError(t, err) + t.Setenv("DATABRICKS_CONFIG_FILE", configFile) + t.Setenv("PATH", "/nothing") + + t.Run("Prompt if nothing is specified", func(t *testing.T) { + expectPrompts(t, workspacePromptFn, &config.Config{}) + }) + + t.Run("Prompt if an account host is specified", func(t *testing.T) { + expectPrompts(t, workspacePromptFn, &config.Config{ + Host: "https://accounts.azuredatabricks.net/", + AccountID: "1234", + Token: "foobar", + }) + }) + + t.Run("Prompt if no credential provider can be configured", func(t *testing.T) { + expectPrompts(t, workspacePromptFn, &config.Config{ + Host: "https://adb-1111.11.azuredatabricks.net/", + }) + }) + + t.Run("Returns if configuration is valid", func(t *testing.T) { + expectReturns(t, workspacePromptFn, &config.Config{ + Host: "https://adb-1111.11.azuredatabricks.net/", + Token: "foobar", + }) + }) + + t.Run("Returns if a valid profile is specified", func(t *testing.T) { + expectReturns(t, workspacePromptFn, &config.Config{ + Profile: "workspace-1111", + }) + }) +} diff --git a/cmd/root/bundle.go b/cmd/root/bundle.go index 10cce67a4..3f9d90db6 100644 --- a/cmd/root/bundle.go +++ b/cmd/root/bundle.go @@ -2,17 +2,15 @@ package root import ( "context" - "os" "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config/mutator" + "github.com/databricks/cli/bundle/env" + envlib "github.com/databricks/cli/libs/env" "github.com/spf13/cobra" "golang.org/x/exp/maps" ) -const envName = "DATABRICKS_BUNDLE_ENV" -const targetName = "DATABRICKS_BUNDLE_TARGET" - // getTarget returns the name of the target to operate in. func getTarget(cmd *cobra.Command) (value string) { // The command line flag takes precedence. @@ -33,13 +31,7 @@ func getTarget(cmd *cobra.Command) (value string) { } // If it's not set, use the environment variable. - target := os.Getenv(targetName) - // If target env is not set with a new variable, try to check for old variable name - // TODO: remove when environments section is not supported anymore - if target == "" { - target = os.Getenv(envName) - } - + target, _ := env.Target(cmd.Context()) return target } @@ -54,7 +46,7 @@ func getProfile(cmd *cobra.Command) (value string) { } // If it's not set, use the environment variable. - return os.Getenv("DATABRICKS_CONFIG_PROFILE") + return envlib.Get(cmd.Context(), "DATABRICKS_CONFIG_PROFILE") } // loadBundle loads the bundle configuration and applies default mutators. diff --git a/cmd/root/bundle_test.go b/cmd/root/bundle_test.go index 09b33d589..3f9641b7e 100644 --- a/cmd/root/bundle_test.go +++ b/cmd/root/bundle_test.go @@ -9,6 +9,7 @@ import ( "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config" + "github.com/databricks/cli/internal/testutil" "github.com/spf13/cobra" "github.com/stretchr/testify/assert" ) @@ -56,6 +57,8 @@ func setup(t *testing.T, cmd *cobra.Command, host string) *bundle.Bundle { } func TestBundleConfigureDefault(t *testing.T) { + testutil.CleanupEnvironment(t) + cmd := emptyCommand(t) b := setup(t, cmd, "https://x.com") assert.NotPanics(t, func() { @@ -64,6 +67,8 @@ func TestBundleConfigureDefault(t *testing.T) { } func TestBundleConfigureWithMultipleMatches(t *testing.T) { + testutil.CleanupEnvironment(t) + cmd := emptyCommand(t) b := setup(t, cmd, "https://a.com") assert.Panics(t, func() { @@ -72,6 +77,8 @@ func TestBundleConfigureWithMultipleMatches(t *testing.T) { } func TestBundleConfigureWithNonExistentProfileFlag(t *testing.T) { + testutil.CleanupEnvironment(t) + cmd := emptyCommand(t) cmd.Flag("profile").Value.Set("NOEXIST") @@ -82,6 +89,8 @@ func TestBundleConfigureWithNonExistentProfileFlag(t *testing.T) { } func TestBundleConfigureWithMismatchedProfile(t *testing.T) { + testutil.CleanupEnvironment(t) + cmd := emptyCommand(t) cmd.Flag("profile").Value.Set("PROFILE-1") @@ -92,6 +101,8 @@ func TestBundleConfigureWithMismatchedProfile(t *testing.T) { } func TestBundleConfigureWithCorrectProfile(t *testing.T) { + testutil.CleanupEnvironment(t) + cmd := emptyCommand(t) cmd.Flag("profile").Value.Set("PROFILE-1") @@ -102,10 +113,8 @@ func TestBundleConfigureWithCorrectProfile(t *testing.T) { } func TestBundleConfigureWithMismatchedProfileEnvVariable(t *testing.T) { + testutil.CleanupEnvironment(t) t.Setenv("DATABRICKS_CONFIG_PROFILE", "PROFILE-1") - t.Cleanup(func() { - t.Setenv("DATABRICKS_CONFIG_PROFILE", "") - }) cmd := emptyCommand(t) b := setup(t, cmd, "https://x.com") @@ -115,10 +124,8 @@ func TestBundleConfigureWithMismatchedProfileEnvVariable(t *testing.T) { } func TestBundleConfigureWithProfileFlagAndEnvVariable(t *testing.T) { + testutil.CleanupEnvironment(t) t.Setenv("DATABRICKS_CONFIG_PROFILE", "NOEXIST") - t.Cleanup(func() { - t.Setenv("DATABRICKS_CONFIG_PROFILE", "") - }) cmd := emptyCommand(t) cmd.Flag("profile").Value.Set("PROFILE-1") diff --git a/cmd/root/io.go b/cmd/root/io.go index 380c01b18..23c7d6c64 100644 --- a/cmd/root/io.go +++ b/cmd/root/io.go @@ -1,9 +1,8 @@ package root import ( - "os" - "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/env" "github.com/databricks/cli/libs/flags" "github.com/spf13/cobra" ) @@ -21,7 +20,7 @@ func initOutputFlag(cmd *cobra.Command) *outputFlag { // Configure defaults from environment, if applicable. // If the provided value is invalid it is ignored. - if v, ok := os.LookupEnv(envOutputFormat); ok { + if v, ok := env.Lookup(cmd.Context(), envOutputFormat); ok { f.output.Set(v) } diff --git a/cmd/root/logger.go b/cmd/root/logger.go index ddfae445a..dca07ca4b 100644 --- a/cmd/root/logger.go +++ b/cmd/root/logger.go @@ -5,9 +5,9 @@ import ( "fmt" "io" "log/slog" - "os" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/env" "github.com/databricks/cli/libs/flags" "github.com/databricks/cli/libs/log" "github.com/fatih/color" @@ -126,13 +126,13 @@ func initLogFlags(cmd *cobra.Command) *logFlags { // Configure defaults from environment, if applicable. // If the provided value is invalid it is ignored. - if v, ok := os.LookupEnv(envLogFile); ok { + if v, ok := env.Lookup(cmd.Context(), envLogFile); ok { f.file.Set(v) } - if v, ok := os.LookupEnv(envLogLevel); ok { + if v, ok := env.Lookup(cmd.Context(), envLogLevel); ok { f.level.Set(v) } - if v, ok := os.LookupEnv(envLogFormat); ok { + if v, ok := env.Lookup(cmd.Context(), envLogFormat); ok { f.output.Set(v) } diff --git a/cmd/root/progress_logger.go b/cmd/root/progress_logger.go index bdf52558b..328b99476 100644 --- a/cmd/root/progress_logger.go +++ b/cmd/root/progress_logger.go @@ -6,6 +6,7 @@ import ( "os" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/env" "github.com/databricks/cli/libs/flags" "github.com/spf13/cobra" "golang.org/x/term" @@ -51,7 +52,7 @@ func initProgressLoggerFlag(cmd *cobra.Command, logFlags *logFlags) *progressLog // Configure defaults from environment, if applicable. // If the provided value is invalid it is ignored. - if v, ok := os.LookupEnv(envProgressFormat); ok { + if v, ok := env.Lookup(cmd.Context(), envProgressFormat); ok { f.Set(v) } diff --git a/cmd/root/root.go b/cmd/root/root.go index c71cf9eac..38eb42ccb 100644 --- a/cmd/root/root.go +++ b/cmd/root/root.go @@ -14,7 +14,7 @@ import ( "github.com/spf13/cobra" ) -func New() *cobra.Command { +func New(ctx context.Context) *cobra.Command { cmd := &cobra.Command{ Use: "databricks", Short: "Databricks CLI", @@ -30,6 +30,10 @@ func New() *cobra.Command { SilenceErrors: true, } + // Pass the context along through the command during initialization. + // It will be overwritten when the command is executed. + cmd.SetContext(ctx) + // Initialize flags logFlags := initLogFlags(cmd) progressLoggerFlag := initProgressLoggerFlag(cmd, logFlags) diff --git a/cmd/root/user_agent_upstream.go b/cmd/root/user_agent_upstream.go index 3e173bda8..f580b4263 100644 --- a/cmd/root/user_agent_upstream.go +++ b/cmd/root/user_agent_upstream.go @@ -2,8 +2,8 @@ package root import ( "context" - "os" + "github.com/databricks/cli/libs/env" "github.com/databricks/databricks-sdk-go/useragent" ) @@ -16,7 +16,7 @@ const upstreamKey = "upstream" const upstreamVersionKey = "upstream-version" func withUpstreamInUserAgent(ctx context.Context) context.Context { - value := os.Getenv(upstreamEnvVar) + value := env.Get(ctx, upstreamEnvVar) if value == "" { return ctx } @@ -24,7 +24,7 @@ func withUpstreamInUserAgent(ctx context.Context) context.Context { ctx = useragent.InContext(ctx, upstreamKey, value) // Include upstream version as well, if set. - value = os.Getenv(upstreamVersionEnvVar) + value = env.Get(ctx, upstreamVersionEnvVar) if value == "" { return ctx } diff --git a/cmd/sync/sync.go b/cmd/sync/sync.go index 4a62123ba..5fdfb169d 100644 --- a/cmd/sync/sync.go +++ b/cmd/sync/sync.go @@ -30,12 +30,12 @@ func (f *syncFlags) syncOptionsFromBundle(cmd *cobra.Command, args []string, b * return nil, fmt.Errorf("SRC and DST are not configurable in the context of a bundle") } - cacheDir, err := b.CacheDir() + cacheDir, err := b.CacheDir(cmd.Context()) if err != nil { return nil, fmt.Errorf("cannot get bundle cache directory: %w", err) } - includes, err := b.GetSyncIncludePatterns() + includes, err := b.GetSyncIncludePatterns(cmd.Context()) if err != nil { return nil, fmt.Errorf("cannot get list of sync includes: %w", err) } diff --git a/cmd/version/version.go b/cmd/version/version.go index 17bb4b9af..653fbb897 100644 --- a/cmd/version/version.go +++ b/cmd/version/version.go @@ -8,9 +8,9 @@ import ( func New() *cobra.Command { cmd := &cobra.Command{ - Use: "version", - Args: cobra.NoArgs, - + Use: "version", + Args: cobra.NoArgs, + Short: "Retrieve information about the current version of this CLI", Annotations: map[string]string{ "template": "Databricks CLI v{{.Version}}\n", }, diff --git a/go.mod b/go.mod index 7e24b0db2..14c85e675 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.21 require ( github.com/briandowns/spinner v1.23.0 // Apache 2.0 - github.com/databricks/databricks-sdk-go v0.19.0 // Apache 2.0 + github.com/databricks/databricks-sdk-go v0.19.1 // Apache 2.0 github.com/fatih/color v1.15.0 // MIT github.com/ghodss/yaml v1.0.0 // MIT + NOTICE github.com/google/uuid v1.3.0 // BSD-3-Clause diff --git a/go.sum b/go.sum index 83bb01b62..20c985b07 100644 --- a/go.sum +++ b/go.sum @@ -36,8 +36,8 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= -github.com/databricks/databricks-sdk-go v0.19.0 h1:Xh5A90/+8ehW7fTqoQbQK5xZu7a/akv3Xwv8UdWB4GU= -github.com/databricks/databricks-sdk-go v0.19.0/go.mod h1:Bt/3i3ry/rQdE6Y+psvkAENlp+LzJHaQK5PsLIstQb4= +github.com/databricks/databricks-sdk-go v0.19.1 h1:hP7xZb+Hd8n0grnEcf2FOMn6lWox7vp5KAan3D2hnzM= +github.com/databricks/databricks-sdk-go v0.19.1/go.mod h1:Bt/3i3ry/rQdE6Y+psvkAENlp+LzJHaQK5PsLIstQb4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/internal/bundle/bundles/python_wheel_task/databricks_template_schema.json b/internal/bundle/bundles/python_wheel_task/databricks_template_schema.json new file mode 100644 index 000000000..f7f4b6342 --- /dev/null +++ b/internal/bundle/bundles/python_wheel_task/databricks_template_schema.json @@ -0,0 +1,21 @@ +{ + "properties": { + "project_name": { + "type": "string", + "default": "my_test_code", + "description": "Unique name for this project" + }, + "spark_version": { + "type": "string", + "description": "Spark version used for job cluster" + }, + "node_type_id": { + "type": "string", + "description": "Node type id for job cluster" + }, + "unique_id": { + "type": "string", + "description": "Unique ID for job name" + } + } +} diff --git a/internal/bundle/bundles/python_wheel_task/template/databricks.yml.tmpl b/internal/bundle/bundles/python_wheel_task/template/databricks.yml.tmpl new file mode 100644 index 000000000..e715cdf1e --- /dev/null +++ b/internal/bundle/bundles/python_wheel_task/template/databricks.yml.tmpl @@ -0,0 +1,24 @@ +bundle: + name: wheel-task + +workspace: + root_path: "~/.bundle/{{.unique_id}}" + +resources: + jobs: + some_other_job: + name: "[${bundle.target}] Test Wheel Job {{.unique_id}}" + tasks: + - task_key: TestTask + new_cluster: + num_workers: 1 + spark_version: "{{.spark_version}}" + node_type_id: "{{.node_type_id}}" + python_wheel_task: + package_name: my_test_code + entry_point: run + parameters: + - "one" + - "two" + libraries: + - whl: ./dist/*.whl diff --git a/internal/bundle/bundles/python_wheel_task/template/setup.py.tmpl b/internal/bundle/bundles/python_wheel_task/template/setup.py.tmpl new file mode 100644 index 000000000..b528657b1 --- /dev/null +++ b/internal/bundle/bundles/python_wheel_task/template/setup.py.tmpl @@ -0,0 +1,15 @@ +from setuptools import setup, find_packages + +import {{.project_name}} + +setup( + name="{{.project_name}}", + version={{.project_name}}.__version__, + author={{.project_name}}.__author__, + url="https://databricks.com", + author_email="john.doe@databricks.com", + description="my example wheel", + packages=find_packages(include=["{{.project_name}}"]), + entry_points={"group1": "run={{.project_name}}.__main__:main"}, + install_requires=["setuptools"], +) diff --git a/internal/bundle/bundles/python_wheel_task/template/{{.project_name}}/__init__.py b/internal/bundle/bundles/python_wheel_task/template/{{.project_name}}/__init__.py new file mode 100644 index 000000000..909f1f322 --- /dev/null +++ b/internal/bundle/bundles/python_wheel_task/template/{{.project_name}}/__init__.py @@ -0,0 +1,2 @@ +__version__ = "0.0.1" +__author__ = "Databricks" diff --git a/internal/bundle/bundles/python_wheel_task/template/{{.project_name}}/__main__.py b/internal/bundle/bundles/python_wheel_task/template/{{.project_name}}/__main__.py new file mode 100644 index 000000000..ea918ce2d --- /dev/null +++ b/internal/bundle/bundles/python_wheel_task/template/{{.project_name}}/__main__.py @@ -0,0 +1,16 @@ +""" +The entry point of the Python Wheel +""" + +import sys + + +def main(): + # This method will print the provided arguments + print("Hello from my func") + print("Got arguments:") + print(sys.argv) + + +if __name__ == "__main__": + main() diff --git a/internal/bundle/helpers.go b/internal/bundle/helpers.go new file mode 100644 index 000000000..3fd4eabc9 --- /dev/null +++ b/internal/bundle/helpers.go @@ -0,0 +1,70 @@ +package bundle + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/internal" + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/flags" + "github.com/databricks/cli/libs/template" +) + +func initTestTemplate(t *testing.T, templateName string, config map[string]any) (string, error) { + templateRoot := filepath.Join("bundles", templateName) + + bundleRoot := t.TempDir() + configFilePath, err := writeConfigFile(t, config) + if err != nil { + return "", err + } + + ctx := root.SetWorkspaceClient(context.Background(), nil) + cmd := cmdio.NewIO(flags.OutputJSON, strings.NewReader(""), os.Stdout, os.Stderr, "bundles") + ctx = cmdio.InContext(ctx, cmd) + + err = template.Materialize(ctx, configFilePath, templateRoot, bundleRoot) + return bundleRoot, err +} + +func writeConfigFile(t *testing.T, config map[string]any) (string, error) { + bytes, err := json.Marshal(config) + if err != nil { + return "", err + } + + dir := t.TempDir() + filepath := filepath.Join(dir, "config.json") + t.Log("Configuration for template: ", string(bytes)) + + err = os.WriteFile(filepath, bytes, 0644) + return filepath, err +} + +func deployBundle(t *testing.T, path string) error { + t.Setenv("BUNDLE_ROOT", path) + c := internal.NewCobraTestRunner(t, "bundle", "deploy", "--force-lock") + _, _, err := c.Run() + return err +} + +func runResource(t *testing.T, path string, key string) (string, error) { + ctx := context.Background() + ctx = cmdio.NewContext(ctx, cmdio.Default()) + + c := internal.NewCobraTestRunnerWithContext(t, ctx, "bundle", "run", key) + stdout, _, err := c.Run() + return stdout.String(), err +} + +func destroyBundle(t *testing.T, path string) error { + t.Setenv("BUNDLE_ROOT", path) + c := internal.NewCobraTestRunner(t, "bundle", "destroy", "--auto-approve") + _, _, err := c.Run() + return err +} diff --git a/internal/bundle/python_wheel_test.go b/internal/bundle/python_wheel_test.go new file mode 100644 index 000000000..ee5d897d6 --- /dev/null +++ b/internal/bundle/python_wheel_test.go @@ -0,0 +1,43 @@ +package bundle + +import ( + "testing" + + "github.com/databricks/cli/internal" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +func TestAccPythonWheelTaskDeployAndRun(t *testing.T) { + env := internal.GetEnvOrSkipTest(t, "CLOUD_ENV") + t.Log(env) + + var nodeTypeId string + if env == "gcp" { + nodeTypeId = "n1-standard-4" + } else if env == "aws" { + nodeTypeId = "i3.xlarge" + } else { + nodeTypeId = "Standard_DS4_v2" + } + + bundleRoot, err := initTestTemplate(t, "python_wheel_task", map[string]any{ + "node_type_id": nodeTypeId, + "unique_id": uuid.New().String(), + "spark_version": "13.2.x-snapshot-scala2.12", + }) + require.NoError(t, err) + + err = deployBundle(t, bundleRoot) + require.NoError(t, err) + + t.Cleanup(func() { + destroyBundle(t, bundleRoot) + }) + + out, err := runResource(t, bundleRoot, "some_other_job") + require.NoError(t, err) + require.Contains(t, out, "Hello from my func") + require.Contains(t, out, "Got arguments:") + require.Contains(t, out, "['python', 'one', 'two']") +} diff --git a/internal/helpers.go b/internal/helpers.go index ddc005173..68c00019a 100644 --- a/internal/helpers.go +++ b/internal/helpers.go @@ -58,6 +58,8 @@ type cobraTestRunner struct { stdout bytes.Buffer stderr bytes.Buffer + ctx context.Context + // Line-by-line output. // Background goroutines populate these channels by reading from stdout/stderr pipes. stdoutLines <-chan string @@ -116,7 +118,7 @@ func (t *cobraTestRunner) RunBackground() { var stdoutW, stderrW io.WriteCloser stdoutR, stdoutW = io.Pipe() stderrR, stderrW = io.Pipe() - root := cmd.New() + root := cmd.New(context.Background()) root.SetOut(stdoutW) root.SetErr(stderrW) root.SetArgs(t.args) @@ -128,7 +130,7 @@ func (t *cobraTestRunner) RunBackground() { t.registerFlagCleanup(root) errch := make(chan error) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.ctx) // Tee stdout/stderr to buffers. stdoutR = io.TeeReader(stdoutR, &t.stdout) @@ -234,6 +236,15 @@ func (c *cobraTestRunner) Eventually(condition func() bool, waitFor time.Duratio func NewCobraTestRunner(t *testing.T, args ...string) *cobraTestRunner { return &cobraTestRunner{ T: t, + ctx: context.Background(), + args: args, + } +} + +func NewCobraTestRunnerWithContext(t *testing.T, ctx context.Context, args ...string) *cobraTestRunner { + return &cobraTestRunner{ + T: t, + ctx: ctx, args: args, } } diff --git a/internal/testutil/env.go b/internal/testutil/env.go new file mode 100644 index 000000000..11a610189 --- /dev/null +++ b/internal/testutil/env.go @@ -0,0 +1,37 @@ +package testutil + +import ( + "os" + "runtime" + "strings" + "testing" +) + +// CleanupEnvironment sets up a pristine environment containing only $PATH and $HOME. +// The original environment is restored upon test completion. +// Note: use of this function is incompatible with parallel execution. +func CleanupEnvironment(t *testing.T) { + // Restore environment when test finishes. + environ := os.Environ() + t.Cleanup(func() { + // Restore original environment. + for _, kv := range environ { + kvs := strings.SplitN(kv, "=", 2) + os.Setenv(kvs[0], kvs[1]) + } + }) + + path := os.Getenv("PATH") + pwd := os.Getenv("PWD") + os.Clearenv() + + // We use t.Setenv instead of os.Setenv because the former actively + // prevents a test being run with t.Parallel. Modifying the environment + // within a test is not compatible with running tests in parallel + // because of isolation; the environment is scoped to the process. + t.Setenv("PATH", path) + t.Setenv("HOME", pwd) + if runtime.GOOS == "windows" { + t.Setenv("USERPROFILE", pwd) + } +} diff --git a/libs/cmdio/io.go b/libs/cmdio/io.go index 9d712e351..cf405a7a4 100644 --- a/libs/cmdio/io.go +++ b/libs/cmdio/io.go @@ -205,6 +205,13 @@ func Prompt(ctx context.Context) *promptui.Prompt { } } +func RunSelect(ctx context.Context, prompt *promptui.Select) (int, string, error) { + c := fromContext(ctx) + prompt.Stdin = io.NopCloser(c.in) + prompt.Stdout = nopWriteCloser{c.err} + return prompt.Run() +} + func (c *cmdIO) simplePrompt(label string) *promptui.Prompt { return &promptui.Prompt{ Label: label, diff --git a/libs/cmdio/logger.go b/libs/cmdio/logger.go index 0663306e1..7d760b998 100644 --- a/libs/cmdio/logger.go +++ b/libs/cmdio/logger.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/databricks/cli/libs/flags" + "github.com/manifoldco/promptui" ) // This is the interface for all io interactions with a user @@ -104,6 +105,36 @@ func AskYesOrNo(ctx context.Context, question string) (bool, error) { return false, nil } +func AskSelect(ctx context.Context, question string, choices []string) (string, error) { + logger, ok := FromContext(ctx) + if !ok { + logger = Default() + } + return logger.AskSelect(question, choices) +} + +func (l *Logger) AskSelect(question string, choices []string) (string, error) { + if l.Mode == flags.ModeJson { + return "", fmt.Errorf("question prompts are not supported in json mode") + } + + prompt := promptui.Select{ + Label: question, + Items: choices, + HideHelp: true, + Templates: &promptui.SelectTemplates{ + Label: "{{.}}: ", + Selected: fmt.Sprintf("%s: {{.}}", question), + }, + } + + _, ans, err := prompt.Run() + if err != nil { + return "", err + } + return ans, nil +} + func (l *Logger) Ask(question string, defaultVal string) (string, error) { if l.Mode == flags.ModeJson { return "", fmt.Errorf("question prompts are not supported in json mode") diff --git a/libs/cmdio/logger_test.go b/libs/cmdio/logger_test.go index da6190462..c5c00d022 100644 --- a/libs/cmdio/logger_test.go +++ b/libs/cmdio/logger_test.go @@ -1,6 +1,7 @@ package cmdio import ( + "context" "testing" "github.com/databricks/cli/libs/flags" @@ -12,3 +13,11 @@ func TestAskFailedInJsonMode(t *testing.T) { _, err := l.Ask("What is your spirit animal?", "") assert.ErrorContains(t, err, "question prompts are not supported in json mode") } + +func TestAskChoiceFailsInJsonMode(t *testing.T) { + l := NewLogger(flags.ModeJson) + ctx := NewContext(context.Background(), l) + + _, err := AskSelect(ctx, "what is a question?", []string{"b", "c", "a"}) + assert.EqualError(t, err, "question prompts are not supported in json mode") +} diff --git a/libs/cmdio/testing.go b/libs/cmdio/testing.go new file mode 100644 index 000000000..43592489e --- /dev/null +++ b/libs/cmdio/testing.go @@ -0,0 +1,46 @@ +package cmdio + +import ( + "bufio" + "context" + "io" +) + +type Test struct { + Done context.CancelFunc + + Stdin *bufio.Writer + Stdout *bufio.Reader + Stderr *bufio.Reader +} + +func SetupTest(ctx context.Context) (context.Context, *Test) { + rin, win := io.Pipe() + rout, wout := io.Pipe() + rerr, werr := io.Pipe() + + cmdio := &cmdIO{ + interactive: true, + in: rin, + out: wout, + err: werr, + } + + ctx, cancel := context.WithCancel(ctx) + ctx = InContext(ctx, cmdio) + + // Wait for context to be done, so we can drain stdin and close the pipes. + go func() { + <-ctx.Done() + rin.Close() + wout.Close() + werr.Close() + }() + + return ctx, &Test{ + Done: cancel, + Stdin: bufio.NewWriter(win), + Stdout: bufio.NewReader(rout), + Stderr: bufio.NewReader(rerr), + } +} diff --git a/libs/env/context.go b/libs/env/context.go new file mode 100644 index 000000000..cf04c1ece --- /dev/null +++ b/libs/env/context.go @@ -0,0 +1,63 @@ +package env + +import ( + "context" + "os" +) + +var envContextKey int + +func copyMap(m map[string]string) map[string]string { + out := make(map[string]string, len(m)) + for k, v := range m { + out[k] = v + } + return out +} + +func getMap(ctx context.Context) map[string]string { + if ctx == nil { + return nil + } + m, ok := ctx.Value(&envContextKey).(map[string]string) + if !ok { + return nil + } + return m +} + +func setMap(ctx context.Context, m map[string]string) context.Context { + return context.WithValue(ctx, &envContextKey, m) +} + +// Lookup key in the context or the the environment. +// Context has precedence. +func Lookup(ctx context.Context, key string) (string, bool) { + m := getMap(ctx) + + // Return if the key is set in the context. + v, ok := m[key] + if ok { + return v, true + } + + // Fall back to the environment. + return os.LookupEnv(key) +} + +// Get key from the context or the environment. +// Context has precedence. +func Get(ctx context.Context, key string) string { + v, _ := Lookup(ctx, key) + return v +} + +// Set key on the context. +// +// Note: this does NOT mutate the processes' actual environment variables. +// It is only visible to other code that uses this package. +func Set(ctx context.Context, key, value string) context.Context { + m := copyMap(getMap(ctx)) + m[key] = value + return setMap(ctx, m) +} diff --git a/libs/env/context_test.go b/libs/env/context_test.go new file mode 100644 index 000000000..9ff194597 --- /dev/null +++ b/libs/env/context_test.go @@ -0,0 +1,41 @@ +package env + +import ( + "context" + "testing" + + "github.com/databricks/cli/internal/testutil" + "github.com/stretchr/testify/assert" +) + +func TestContext(t *testing.T) { + testutil.CleanupEnvironment(t) + t.Setenv("FOO", "bar") + + ctx0 := context.Background() + + // Get + assert.Equal(t, "bar", Get(ctx0, "FOO")) + assert.Equal(t, "", Get(ctx0, "dontexist")) + + // Lookup + v, ok := Lookup(ctx0, "FOO") + assert.True(t, ok) + assert.Equal(t, "bar", v) + v, ok = Lookup(ctx0, "dontexist") + assert.False(t, ok) + assert.Equal(t, "", v) + + // Set and get new context. + // Verify that the previous context remains unchanged. + ctx1 := Set(ctx0, "FOO", "baz") + assert.Equal(t, "baz", Get(ctx1, "FOO")) + assert.Equal(t, "bar", Get(ctx0, "FOO")) + + // Set and get new context. + // Verify that the previous contexts remains unchanged. + ctx2 := Set(ctx1, "FOO", "qux") + assert.Equal(t, "qux", Get(ctx2, "FOO")) + assert.Equal(t, "baz", Get(ctx1, "FOO")) + assert.Equal(t, "bar", Get(ctx0, "FOO")) +} diff --git a/libs/env/pkg.go b/libs/env/pkg.go new file mode 100644 index 000000000..e0be7e225 --- /dev/null +++ b/libs/env/pkg.go @@ -0,0 +1,7 @@ +package env + +// The env package provides functions for working with environment variables +// and allowing for overrides via the context.Context. This is useful for +// testing where tainting a processes' environment is at odds with parallelism. +// Use of a context.Context to store variable overrides means tests can be +// parallelized without worrying about environment variable interference. diff --git a/libs/jsonschema/instance.go b/libs/jsonschema/instance.go new file mode 100644 index 000000000..229a45b53 --- /dev/null +++ b/libs/jsonschema/instance.go @@ -0,0 +1,113 @@ +package jsonschema + +import ( + "encoding/json" + "fmt" + "os" + "slices" +) + +// Load a JSON document and validate it against the JSON schema. Instance here +// refers to a JSON document. see: https://json-schema.org/draft/2020-12/json-schema-core.html#name-instance +func (s *Schema) LoadInstance(path string) (map[string]any, error) { + instance := make(map[string]any) + b, err := os.ReadFile(path) + if err != nil { + return nil, err + } + err = json.Unmarshal(b, &instance) + if err != nil { + return nil, err + } + + // The default JSON unmarshaler parses untyped number values as float64. + // We convert integer properties from float64 to int64 here. + for name, v := range instance { + propertySchema, ok := s.Properties[name] + if !ok { + continue + } + if propertySchema.Type != IntegerType { + continue + } + integerValue, err := toInteger(v) + if err != nil { + return nil, fmt.Errorf("failed to parse property %s: %w", name, err) + } + instance[name] = integerValue + } + return instance, s.ValidateInstance(instance) +} + +func (s *Schema) ValidateInstance(instance map[string]any) error { + for _, fn := range []func(map[string]any) error{ + s.validateAdditionalProperties, + s.validateEnum, + s.validateRequired, + s.validateTypes, + } { + err := fn(instance) + if err != nil { + return err + } + } + return nil +} + +// If additional properties is set to false, this function validates instance only +// contains properties defined in the schema. +func (s *Schema) validateAdditionalProperties(instance map[string]any) error { + // Note: AdditionalProperties has the type any. + if s.AdditionalProperties != false { + return nil + } + for k := range instance { + _, ok := s.Properties[k] + if !ok { + return fmt.Errorf("property %s is not defined in the schema", k) + } + } + return nil +} + +// This function validates that all require properties in the schema have values +// in the instance. +func (s *Schema) validateRequired(instance map[string]any) error { + for _, name := range s.Required { + if _, ok := instance[name]; !ok { + return fmt.Errorf("no value provided for required property %s", name) + } + } + return nil +} + +// Validates the types of all input properties values match their types defined in the schema +func (s *Schema) validateTypes(instance map[string]any) error { + for k, v := range instance { + fieldInfo, ok := s.Properties[k] + if !ok { + continue + } + err := validateType(v, fieldInfo.Type) + if err != nil { + return fmt.Errorf("incorrect type for property %s: %w", k, err) + } + } + return nil +} + +func (s *Schema) validateEnum(instance map[string]any) error { + for k, v := range instance { + fieldInfo, ok := s.Properties[k] + if !ok { + continue + } + if fieldInfo.Enum == nil { + continue + } + if !slices.Contains(fieldInfo.Enum, v) { + return fmt.Errorf("expected value of property %s to be one of %v. Found: %v", k, fieldInfo.Enum, v) + } + } + return nil +} diff --git a/libs/jsonschema/instance_test.go b/libs/jsonschema/instance_test.go new file mode 100644 index 000000000..ffd10ca43 --- /dev/null +++ b/libs/jsonschema/instance_test.go @@ -0,0 +1,155 @@ +package jsonschema + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestValidateInstanceAdditionalPropertiesPermitted(t *testing.T) { + instance := map[string]any{ + "int_val": 1, + "float_val": 1.0, + "bool_val": false, + "an_additional_property": "abc", + } + + schema, err := Load("./testdata/instance-validate/test-schema.json") + require.NoError(t, err) + + err = schema.validateAdditionalProperties(instance) + assert.NoError(t, err) + + err = schema.ValidateInstance(instance) + assert.NoError(t, err) +} + +func TestValidateInstanceAdditionalPropertiesForbidden(t *testing.T) { + instance := map[string]any{ + "int_val": 1, + "float_val": 1.0, + "bool_val": false, + "an_additional_property": "abc", + } + + schema, err := Load("./testdata/instance-validate/test-schema-no-additional-properties.json") + require.NoError(t, err) + + err = schema.validateAdditionalProperties(instance) + assert.EqualError(t, err, "property an_additional_property is not defined in the schema") + + err = schema.ValidateInstance(instance) + assert.EqualError(t, err, "property an_additional_property is not defined in the schema") + + instanceWOAdditionalProperties := map[string]any{ + "int_val": 1, + "float_val": 1.0, + "bool_val": false, + } + + err = schema.validateAdditionalProperties(instanceWOAdditionalProperties) + assert.NoError(t, err) + + err = schema.ValidateInstance(instanceWOAdditionalProperties) + assert.NoError(t, err) +} + +func TestValidateInstanceTypes(t *testing.T) { + schema, err := Load("./testdata/instance-validate/test-schema.json") + require.NoError(t, err) + + validInstance := map[string]any{ + "int_val": 1, + "float_val": 1.0, + "bool_val": false, + } + + err = schema.validateTypes(validInstance) + assert.NoError(t, err) + + err = schema.ValidateInstance(validInstance) + assert.NoError(t, err) + + invalidInstance := map[string]any{ + "int_val": "abc", + "float_val": 1.0, + "bool_val": false, + } + + err = schema.validateTypes(invalidInstance) + assert.EqualError(t, err, "incorrect type for property int_val: expected type integer, but value is \"abc\"") + + err = schema.ValidateInstance(invalidInstance) + assert.EqualError(t, err, "incorrect type for property int_val: expected type integer, but value is \"abc\"") +} + +func TestValidateInstanceRequired(t *testing.T) { + schema, err := Load("./testdata/instance-validate/test-schema-some-fields-required.json") + require.NoError(t, err) + + validInstance := map[string]any{ + "int_val": 1, + "float_val": 1.0, + "bool_val": false, + } + err = schema.validateRequired(validInstance) + assert.NoError(t, err) + err = schema.ValidateInstance(validInstance) + assert.NoError(t, err) + + invalidInstance := map[string]any{ + "string_val": "abc", + "float_val": 1.0, + "bool_val": false, + } + err = schema.validateRequired(invalidInstance) + assert.EqualError(t, err, "no value provided for required property int_val") + err = schema.ValidateInstance(invalidInstance) + assert.EqualError(t, err, "no value provided for required property int_val") +} + +func TestLoadInstance(t *testing.T) { + schema, err := Load("./testdata/instance-validate/test-schema.json") + require.NoError(t, err) + + // Expect the instance to be loaded successfully. + instance, err := schema.LoadInstance("./testdata/instance-load/valid-instance.json") + assert.NoError(t, err) + assert.Equal(t, map[string]any{ + "bool_val": false, + "int_val": int64(1), + "string_val": "abc", + "float_val": 2.0, + }, instance) + + // Expect instance validation against the schema to fail. + _, err = schema.LoadInstance("./testdata/instance-load/invalid-type-instance.json") + assert.EqualError(t, err, "incorrect type for property string_val: expected type string, but value is 123") +} + +func TestValidateInstanceEnum(t *testing.T) { + schema, err := Load("./testdata/instance-validate/test-schema-enum.json") + require.NoError(t, err) + + validInstance := map[string]any{ + "foo": "b", + "bar": int64(6), + } + assert.NoError(t, schema.validateEnum(validInstance)) + assert.NoError(t, schema.ValidateInstance(validInstance)) + + invalidStringInstance := map[string]any{ + "foo": "d", + "bar": int64(2), + } + assert.EqualError(t, schema.validateEnum(invalidStringInstance), "expected value of property foo to be one of [a b c]. Found: d") + assert.EqualError(t, schema.ValidateInstance(invalidStringInstance), "expected value of property foo to be one of [a b c]. Found: d") + + invalidIntInstance := map[string]any{ + "foo": "a", + "bar": int64(1), + } + assert.EqualError(t, schema.validateEnum(invalidIntInstance), "expected value of property bar to be one of [2 4 6]. Found: 1") + assert.EqualError(t, schema.ValidateInstance(invalidIntInstance), "expected value of property bar to be one of [2 4 6]. Found: 1") +} diff --git a/libs/jsonschema/schema.go b/libs/jsonschema/schema.go index 87e9acd56..108102a64 100644 --- a/libs/jsonschema/schema.go +++ b/libs/jsonschema/schema.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "os" + "slices" ) // defines schema for a json object @@ -41,6 +42,9 @@ type Schema struct { // Default value for the property / object Default any `json:"default,omitempty"` + // List of valid values for a JSON instance for this schema. + Enum []any `json:"enum,omitempty"` + // Extension embeds our custom JSON schema extensions. Extension } @@ -58,6 +62,7 @@ const ( ) func (schema *Schema) validate() error { + // Validate property types are all valid JSON schema types. for _, v := range schema.Properties { switch v.Type { case NumberType, BooleanType, StringType, IntegerType: @@ -72,6 +77,41 @@ func (schema *Schema) validate() error { return fmt.Errorf("type %s is not a recognized json schema type", v.Type) } } + + // Validate default property values are consistent with types. + for name, property := range schema.Properties { + if property.Default == nil { + continue + } + if err := validateType(property.Default, property.Type); err != nil { + return fmt.Errorf("type validation for default value of property %s failed: %w", name, err) + } + } + + // Validate enum field values for properties are consistent with types. + for name, property := range schema.Properties { + if property.Enum == nil { + continue + } + for i, enum := range property.Enum { + err := validateType(enum, property.Type) + if err != nil { + return fmt.Errorf("type validation for enum at index %v failed for property %s: %w", i, name, err) + } + } + } + + // Validate default value is contained in the list of enums if both are defined. + for name, property := range schema.Properties { + if property.Default == nil || property.Enum == nil { + continue + } + // We expect the default value to be consistent with the list of enum + // values. + if !slices.Contains(property.Enum, property.Default) { + return fmt.Errorf("list of enum values for property %s does not contain default value %v: %v", name, property.Default, property.Enum) + } + } return nil } @@ -85,5 +125,31 @@ func Load(path string) (*Schema, error) { if err != nil { return nil, err } + + // Convert the default values of top-level properties to integers. + // This is required because the default JSON unmarshaler parses numbers + // as floats when the Golang field it's being loaded to is untyped. + // + // NOTE: properties can be recursively defined in a schema, but the current + // use-cases only uses the first layer of properties so we skip converting + // any recursive properties. + for name, property := range schema.Properties { + if property.Type != IntegerType { + continue + } + if property.Default != nil { + property.Default, err = toInteger(property.Default) + if err != nil { + return nil, fmt.Errorf("failed to parse default value for property %s: %w", name, err) + } + } + for i, enum := range property.Enum { + property.Enum[i], err = toInteger(enum) + if err != nil { + return nil, fmt.Errorf("failed to parse enum value %v at index %v for property %s: %w", enum, i, name, err) + } + } + } + return schema, schema.validate() } diff --git a/libs/jsonschema/schema_test.go b/libs/jsonschema/schema_test.go index 76112492f..db559ea88 100644 --- a/libs/jsonschema/schema_test.go +++ b/libs/jsonschema/schema_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestJsonSchemaValidate(t *testing.T) { +func TestSchemaValidateTypeNames(t *testing.T) { var err error toSchema := func(s string) *Schema { return &Schema{ @@ -42,3 +42,100 @@ func TestJsonSchemaValidate(t *testing.T) { err = toSchema("foobar").validate() assert.EqualError(t, err, "type foobar is not a recognized json schema type") } + +func TestSchemaLoadIntegers(t *testing.T) { + schema, err := Load("./testdata/schema-load-int/schema-valid.json") + assert.NoError(t, err) + assert.Equal(t, int64(1), schema.Properties["abc"].Default) + assert.Equal(t, []any{int64(1), int64(2), int64(3)}, schema.Properties["abc"].Enum) +} + +func TestSchemaLoadIntegersWithInvalidDefault(t *testing.T) { + _, err := Load("./testdata/schema-load-int/schema-invalid-default.json") + assert.EqualError(t, err, "failed to parse default value for property abc: expected integer value, got: 1.1") +} + +func TestSchemaLoadIntegersWithInvalidEnums(t *testing.T) { + _, err := Load("./testdata/schema-load-int/schema-invalid-enum.json") + assert.EqualError(t, err, "failed to parse enum value 2.4 at index 1 for property abc: expected integer value, got: 2.4") +} + +func TestSchemaValidateDefaultType(t *testing.T) { + invalidSchema := &Schema{ + Properties: map[string]*Schema{ + "foo": { + Type: "number", + Default: "abc", + }, + }, + } + + err := invalidSchema.validate() + assert.EqualError(t, err, "type validation for default value of property foo failed: expected type float, but value is \"abc\"") + + validSchema := &Schema{ + Properties: map[string]*Schema{ + "foo": { + Type: "boolean", + Default: true, + }, + }, + } + + err = validSchema.validate() + assert.NoError(t, err) +} + +func TestSchemaValidateEnumType(t *testing.T) { + invalidSchema := &Schema{ + Properties: map[string]*Schema{ + "foo": { + Type: "boolean", + Enum: []any{true, "false"}, + }, + }, + } + + err := invalidSchema.validate() + assert.EqualError(t, err, "type validation for enum at index 1 failed for property foo: expected type boolean, but value is \"false\"") + + validSchema := &Schema{ + Properties: map[string]*Schema{ + "foo": { + Type: "boolean", + Enum: []any{true, false}, + }, + }, + } + + err = validSchema.validate() + assert.NoError(t, err) +} + +func TestSchemaValidateErrorWhenDefaultValueIsNotInEnums(t *testing.T) { + invalidSchema := &Schema{ + Properties: map[string]*Schema{ + "foo": { + Type: "string", + Default: "abc", + Enum: []any{"def", "ghi"}, + }, + }, + } + + err := invalidSchema.validate() + assert.EqualError(t, err, "list of enum values for property foo does not contain default value abc: [def ghi]") + + validSchema := &Schema{ + Properties: map[string]*Schema{ + "foo": { + Type: "string", + Default: "abc", + Enum: []any{"def", "ghi", "abc"}, + }, + }, + } + + err = validSchema.validate() + assert.NoError(t, err) +} diff --git a/libs/jsonschema/testdata/instance-load/invalid-type-instance.json b/libs/jsonschema/testdata/instance-load/invalid-type-instance.json new file mode 100644 index 000000000..c55b6fccb --- /dev/null +++ b/libs/jsonschema/testdata/instance-load/invalid-type-instance.json @@ -0,0 +1,6 @@ +{ + "int_val": 1, + "bool_val": false, + "string_val": 123, + "float_val": 3.0 +} diff --git a/libs/jsonschema/testdata/instance-load/valid-instance.json b/libs/jsonschema/testdata/instance-load/valid-instance.json new file mode 100644 index 000000000..7d4dc818a --- /dev/null +++ b/libs/jsonschema/testdata/instance-load/valid-instance.json @@ -0,0 +1,6 @@ +{ + "int_val": 1, + "bool_val": false, + "string_val": "abc", + "float_val": 2.0 +} diff --git a/libs/jsonschema/testdata/instance-validate/test-schema-enum.json b/libs/jsonschema/testdata/instance-validate/test-schema-enum.json new file mode 100644 index 000000000..75ffd6eb8 --- /dev/null +++ b/libs/jsonschema/testdata/instance-validate/test-schema-enum.json @@ -0,0 +1,12 @@ +{ + "properties": { + "foo": { + "type": "string", + "enum": ["a", "b", "c"] + }, + "bar": { + "type": "integer", + "enum": [2,4,6] + } + } +} diff --git a/libs/jsonschema/testdata/instance-validate/test-schema-no-additional-properties.json b/libs/jsonschema/testdata/instance-validate/test-schema-no-additional-properties.json new file mode 100644 index 000000000..98b19d5a4 --- /dev/null +++ b/libs/jsonschema/testdata/instance-validate/test-schema-no-additional-properties.json @@ -0,0 +1,19 @@ +{ + "properties": { + "int_val": { + "type": "integer", + "default": 123 + }, + "float_val": { + "type": "number" + }, + "bool_val": { + "type": "boolean" + }, + "string_val": { + "type": "string", + "default": "abc" + } + }, + "additionalProperties": false +} diff --git a/libs/jsonschema/testdata/instance-validate/test-schema-some-fields-required.json b/libs/jsonschema/testdata/instance-validate/test-schema-some-fields-required.json new file mode 100644 index 000000000..465811034 --- /dev/null +++ b/libs/jsonschema/testdata/instance-validate/test-schema-some-fields-required.json @@ -0,0 +1,19 @@ +{ + "properties": { + "int_val": { + "type": "integer", + "default": 123 + }, + "float_val": { + "type": "number" + }, + "bool_val": { + "type": "boolean" + }, + "string_val": { + "type": "string", + "default": "abc" + } + }, + "required": ["int_val", "float_val", "bool_val"] +} diff --git a/libs/jsonschema/testdata/instance-validate/test-schema.json b/libs/jsonschema/testdata/instance-validate/test-schema.json new file mode 100644 index 000000000..41eb82519 --- /dev/null +++ b/libs/jsonschema/testdata/instance-validate/test-schema.json @@ -0,0 +1,18 @@ +{ + "properties": { + "int_val": { + "type": "integer", + "default": 123 + }, + "float_val": { + "type": "number" + }, + "bool_val": { + "type": "boolean" + }, + "string_val": { + "type": "string", + "default": "abc" + } + } +} diff --git a/libs/jsonschema/testdata/schema-load-int/schema-invalid-default.json b/libs/jsonschema/testdata/schema-load-int/schema-invalid-default.json new file mode 100644 index 000000000..1e709f622 --- /dev/null +++ b/libs/jsonschema/testdata/schema-load-int/schema-invalid-default.json @@ -0,0 +1,9 @@ +{ + "type": "object", + "properties": { + "abc": { + "type": "integer", + "default": 1.1 + } + } +} diff --git a/libs/jsonschema/testdata/schema-load-int/schema-invalid-enum.json b/libs/jsonschema/testdata/schema-load-int/schema-invalid-enum.json new file mode 100644 index 000000000..5bd2b3f2b --- /dev/null +++ b/libs/jsonschema/testdata/schema-load-int/schema-invalid-enum.json @@ -0,0 +1,10 @@ +{ + "type": "object", + "properties": { + "abc": { + "type": "integer", + "default": 1, + "enum": [1,2.4,3] + } + } +} diff --git a/libs/jsonschema/testdata/schema-load-int/schema-valid.json b/libs/jsonschema/testdata/schema-load-int/schema-valid.json new file mode 100644 index 000000000..a1167a6c9 --- /dev/null +++ b/libs/jsonschema/testdata/schema-load-int/schema-valid.json @@ -0,0 +1,10 @@ +{ + "type": "object", + "properties": { + "abc": { + "type": "integer", + "default": 1, + "enum": [1,2,3] + } + } +} diff --git a/libs/template/utils.go b/libs/jsonschema/utils.go similarity index 79% rename from libs/template/utils.go rename to libs/jsonschema/utils.go index ade6a5730..66db9603e 100644 --- a/libs/template/utils.go +++ b/libs/jsonschema/utils.go @@ -1,11 +1,9 @@ -package template +package jsonschema import ( "errors" "fmt" "strconv" - - "github.com/databricks/cli/libs/jsonschema" ) // function to check whether a float value represents an integer @@ -40,41 +38,53 @@ func toInteger(v any) (int64, error) { } } -func toString(v any, T jsonschema.Type) (string, error) { +func ToString(v any, T Type) (string, error) { switch T { - case jsonschema.BooleanType: + case BooleanType: boolVal, ok := v.(bool) if !ok { return "", fmt.Errorf("expected bool, got: %#v", v) } return strconv.FormatBool(boolVal), nil - case jsonschema.StringType: + case StringType: strVal, ok := v.(string) if !ok { return "", fmt.Errorf("expected string, got: %#v", v) } return strVal, nil - case jsonschema.NumberType: + case NumberType: floatVal, ok := v.(float64) if !ok { return "", fmt.Errorf("expected float, got: %#v", v) } return strconv.FormatFloat(floatVal, 'f', -1, 64), nil - case jsonschema.IntegerType: + case IntegerType: intVal, err := toInteger(v) if err != nil { return "", err } return strconv.FormatInt(intVal, 10), nil - case jsonschema.ArrayType, jsonschema.ObjectType: + case ArrayType, ObjectType: return "", fmt.Errorf("cannot format object of type %s as a string. Value of object: %#v", T, v) default: return "", fmt.Errorf("unknown json schema type: %q", T) } } -func fromString(s string, T jsonschema.Type) (any, error) { - if T == jsonschema.StringType { +func ToStringSlice(arr []any, T Type) ([]string, error) { + res := []string{} + for _, v := range arr { + s, err := ToString(v, T) + if err != nil { + return nil, err + } + res = append(res, s) + } + return res, nil +} + +func FromString(s string, T Type) (any, error) { + if T == StringType { return s, nil } @@ -83,13 +93,13 @@ func fromString(s string, T jsonschema.Type) (any, error) { var err error switch T { - case jsonschema.BooleanType: + case BooleanType: v, err = strconv.ParseBool(s) - case jsonschema.NumberType: + case NumberType: v, err = strconv.ParseFloat(s, 32) - case jsonschema.IntegerType: + case IntegerType: v, err = strconv.ParseInt(s, 10, 64) - case jsonschema.ArrayType, jsonschema.ObjectType: + case ArrayType, ObjectType: return "", fmt.Errorf("cannot parse string as object of type %s. Value of string: %q", T, s) default: return "", fmt.Errorf("unknown json schema type: %q", T) diff --git a/libs/template/utils_test.go b/libs/jsonschema/utils_test.go similarity index 70% rename from libs/template/utils_test.go rename to libs/jsonschema/utils_test.go index 1e038aac6..29529aaa9 100644 --- a/libs/template/utils_test.go +++ b/libs/jsonschema/utils_test.go @@ -1,10 +1,9 @@ -package template +package jsonschema import ( "math" "testing" - "github.com/databricks/cli/libs/jsonschema" "github.com/stretchr/testify/assert" ) @@ -50,72 +49,82 @@ func TestTemplateToInteger(t *testing.T) { } func TestTemplateToString(t *testing.T) { - s, err := toString(true, jsonschema.BooleanType) + s, err := ToString(true, BooleanType) assert.NoError(t, err) assert.Equal(t, "true", s) - s, err = toString("abc", jsonschema.StringType) + s, err = ToString("abc", StringType) assert.NoError(t, err) assert.Equal(t, "abc", s) - s, err = toString(1.1, jsonschema.NumberType) + s, err = ToString(1.1, NumberType) assert.NoError(t, err) assert.Equal(t, "1.1", s) - s, err = toString(2, jsonschema.IntegerType) + s, err = ToString(2, IntegerType) assert.NoError(t, err) assert.Equal(t, "2", s) - _, err = toString([]string{}, jsonschema.ArrayType) + _, err = ToString([]string{}, ArrayType) assert.EqualError(t, err, "cannot format object of type array as a string. Value of object: []string{}") - _, err = toString("true", jsonschema.BooleanType) + _, err = ToString("true", BooleanType) assert.EqualError(t, err, "expected bool, got: \"true\"") - _, err = toString(123, jsonschema.StringType) + _, err = ToString(123, StringType) assert.EqualError(t, err, "expected string, got: 123") - _, err = toString(false, jsonschema.NumberType) + _, err = ToString(false, NumberType) assert.EqualError(t, err, "expected float, got: false") - _, err = toString("abc", jsonschema.IntegerType) + _, err = ToString("abc", IntegerType) assert.EqualError(t, err, "cannot convert \"abc\" to an integer") - _, err = toString("abc", "foobar") + _, err = ToString("abc", "foobar") assert.EqualError(t, err, "unknown json schema type: \"foobar\"") } func TestTemplateFromString(t *testing.T) { - v, err := fromString("true", jsonschema.BooleanType) + v, err := FromString("true", BooleanType) assert.NoError(t, err) assert.Equal(t, true, v) - v, err = fromString("abc", jsonschema.StringType) + v, err = FromString("abc", StringType) assert.NoError(t, err) assert.Equal(t, "abc", v) - v, err = fromString("1.1", jsonschema.NumberType) + v, err = FromString("1.1", NumberType) assert.NoError(t, err) // Floating point conversions are not perfect assert.True(t, (v.(float64)-1.1) < 0.000001) - v, err = fromString("12345", jsonschema.IntegerType) + v, err = FromString("12345", IntegerType) assert.NoError(t, err) assert.Equal(t, int64(12345), v) - v, err = fromString("123", jsonschema.NumberType) + v, err = FromString("123", NumberType) assert.NoError(t, err) assert.Equal(t, float64(123), v) - _, err = fromString("qrt", jsonschema.ArrayType) + _, err = FromString("qrt", ArrayType) assert.EqualError(t, err, "cannot parse string as object of type array. Value of string: \"qrt\"") - _, err = fromString("abc", jsonschema.IntegerType) + _, err = FromString("abc", IntegerType) assert.EqualError(t, err, "could not parse \"abc\" as a integer: strconv.ParseInt: parsing \"abc\": invalid syntax") - _, err = fromString("1.0", jsonschema.IntegerType) + _, err = FromString("1.0", IntegerType) assert.EqualError(t, err, "could not parse \"1.0\" as a integer: strconv.ParseInt: parsing \"1.0\": invalid syntax") - _, err = fromString("1.0", "foobar") + _, err = FromString("1.0", "foobar") assert.EqualError(t, err, "unknown json schema type: \"foobar\"") } + +func TestTemplateToStringSlice(t *testing.T) { + s, err := ToStringSlice([]any{"a", "b", "c"}, StringType) + assert.NoError(t, err) + assert.Equal(t, []string{"a", "b", "c"}, s) + + s, err = ToStringSlice([]any{1.1, 2.2, 3.3}, NumberType) + assert.NoError(t, err) + assert.Equal(t, []string{"1.1", "2.2", "3.3"}, s) +} diff --git a/libs/template/validators.go b/libs/jsonschema/validate_type.go similarity index 68% rename from libs/template/validators.go rename to libs/jsonschema/validate_type.go index 209700b63..125d6b20b 100644 --- a/libs/template/validators.go +++ b/libs/jsonschema/validate_type.go @@ -1,17 +1,15 @@ -package template +package jsonschema import ( "fmt" "reflect" "slices" - - "github.com/databricks/cli/libs/jsonschema" ) -type validator func(v any) error +type validateTypeFunc func(v any) error -func validateType(v any, fieldType jsonschema.Type) error { - validateFunc, ok := validators[fieldType] +func validateType(v any, fieldType Type) error { + validateFunc, ok := validateTypeFuncs[fieldType] if !ok { return nil } @@ -50,9 +48,9 @@ func validateInteger(v any) error { return nil } -var validators map[jsonschema.Type]validator = map[jsonschema.Type]validator{ - jsonschema.StringType: validateString, - jsonschema.BooleanType: validateBoolean, - jsonschema.IntegerType: validateInteger, - jsonschema.NumberType: validateNumber, +var validateTypeFuncs map[Type]validateTypeFunc = map[Type]validateTypeFunc{ + StringType: validateString, + BooleanType: validateBoolean, + IntegerType: validateInteger, + NumberType: validateNumber, } diff --git a/libs/template/validators_test.go b/libs/jsonschema/validate_type_test.go similarity index 75% rename from libs/template/validators_test.go rename to libs/jsonschema/validate_type_test.go index f34f037a1..36d9e5758 100644 --- a/libs/template/validators_test.go +++ b/libs/jsonschema/validate_type_test.go @@ -1,9 +1,8 @@ -package template +package jsonschema import ( "testing" - "github.com/databricks/cli/libs/jsonschema" "github.com/stretchr/testify/assert" ) @@ -77,53 +76,53 @@ func TestValidatorInt(t *testing.T) { func TestTemplateValidateType(t *testing.T) { // assert validation passing - err := validateType(int(0), jsonschema.IntegerType) + err := validateType(int(0), IntegerType) assert.NoError(t, err) - err = validateType(int32(1), jsonschema.IntegerType) + err = validateType(int32(1), IntegerType) assert.NoError(t, err) - err = validateType(int64(1), jsonschema.IntegerType) + err = validateType(int64(1), IntegerType) assert.NoError(t, err) - err = validateType(float32(1.1), jsonschema.NumberType) + err = validateType(float32(1.1), NumberType) assert.NoError(t, err) - err = validateType(float64(1.2), jsonschema.NumberType) + err = validateType(float64(1.2), NumberType) assert.NoError(t, err) - err = validateType(false, jsonschema.BooleanType) + err = validateType(false, BooleanType) assert.NoError(t, err) - err = validateType("abc", jsonschema.StringType) + err = validateType("abc", StringType) assert.NoError(t, err) // assert validation failing for integers - err = validateType(float64(1.2), jsonschema.IntegerType) + err = validateType(float64(1.2), IntegerType) assert.ErrorContains(t, err, "expected type integer, but value is 1.2") - err = validateType(true, jsonschema.IntegerType) + err = validateType(true, IntegerType) assert.ErrorContains(t, err, "expected type integer, but value is true") - err = validateType("abc", jsonschema.IntegerType) + err = validateType("abc", IntegerType) assert.ErrorContains(t, err, "expected type integer, but value is \"abc\"") // assert validation failing for floats - err = validateType(true, jsonschema.NumberType) + err = validateType(true, NumberType) assert.ErrorContains(t, err, "expected type float, but value is true") - err = validateType("abc", jsonschema.NumberType) + err = validateType("abc", NumberType) assert.ErrorContains(t, err, "expected type float, but value is \"abc\"") - err = validateType(int(1), jsonschema.NumberType) + err = validateType(int(1), NumberType) assert.ErrorContains(t, err, "expected type float, but value is 1") // assert validation failing for boolean - err = validateType(int(1), jsonschema.BooleanType) + err = validateType(int(1), BooleanType) assert.ErrorContains(t, err, "expected type boolean, but value is 1") - err = validateType(float64(1), jsonschema.BooleanType) + err = validateType(float64(1), BooleanType) assert.ErrorContains(t, err, "expected type boolean, but value is 1") - err = validateType("abc", jsonschema.BooleanType) + err = validateType("abc", BooleanType) assert.ErrorContains(t, err, "expected type boolean, but value is \"abc\"") // assert validation failing for string - err = validateType(int(1), jsonschema.StringType) + err = validateType(int(1), StringType) assert.ErrorContains(t, err, "expected type string, but value is 1") - err = validateType(float64(1), jsonschema.StringType) + err = validateType(float64(1), StringType) assert.ErrorContains(t, err, "expected type string, but value is 1") - err = validateType(false, jsonschema.StringType) + err = validateType(false, StringType) assert.ErrorContains(t, err, "expected type string, but value is false") } diff --git a/libs/template/config.go b/libs/template/config.go index 8a1ed6c82..21618ac9a 100644 --- a/libs/template/config.go +++ b/libs/template/config.go @@ -2,12 +2,11 @@ package template import ( "context" - "encoding/json" "fmt" - "os" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/jsonschema" + "golang.org/x/exp/maps" ) type config struct { @@ -26,6 +25,9 @@ func newConfig(ctx context.Context, schemaPath string) (*config, error) { return nil, err } + // Do not allow template input variables that are not defined in the schema. + schema.AdditionalProperties = false + // Return config return &config{ ctx: ctx, @@ -45,32 +47,10 @@ func validateSchema(schema *jsonschema.Schema) error { // Reads json file at path and assigns values from the file func (c *config) assignValuesFromFile(path string) error { - // Read the config file - configFromFile := make(map[string]any, 0) - b, err := os.ReadFile(path) + // Load the config file. + configFromFile, err := c.schema.LoadInstance(path) if err != nil { - return err - } - err = json.Unmarshal(b, &configFromFile) - if err != nil { - return err - } - - // Cast any integer properties, from float to integer. Required because - // the json unmarshaller treats all json numbers as floating point - for name, floatVal := range configFromFile { - property, ok := c.schema.Properties[name] - if !ok { - return fmt.Errorf("%s is not defined as an input parameter for the template", name) - } - if property.Type != jsonschema.IntegerType { - continue - } - v, err := toInteger(floatVal) - if err != nil { - return fmt.Errorf("failed to cast value %v of property %s from file %s to an integer: %w", floatVal, name, path, err) - } - configFromFile[name] = v + return fmt.Errorf("failed to load config from file %s: %w", path, err) } // Write configs from the file to the input map, not overwriting any existing @@ -91,26 +71,11 @@ func (c *config) assignDefaultValues() error { if _, ok := c.values[name]; ok { continue } - // No default value defined for the property if property.Default == nil { continue } - - // Assign default value if property is not an integer - if property.Type != jsonschema.IntegerType { - c.values[name] = property.Default - continue - } - - // Cast default value to int before assigning to an integer configuration. - // Required because untyped field Default will read all numbers as floats - // during unmarshalling - v, err := toInteger(property.Default) - if err != nil { - return fmt.Errorf("failed to cast default value %v of property %s to an integer: %w", property.Default, name, err) - } - c.values[name] = v + c.values[name] = property.Default } return nil } @@ -130,20 +95,34 @@ func (c *config) promptForValues() error { var defaultVal string var err error if property.Default != nil { - defaultVal, err = toString(property.Default, property.Type) + defaultVal, err = jsonschema.ToString(property.Default, property.Type) if err != nil { return err } } // Get user input by running the prompt - userInput, err := cmdio.Ask(c.ctx, property.Description, defaultVal) - if err != nil { - return err + var userInput string + if property.Enum != nil { + // convert list of enums to string slice + enums, err := jsonschema.ToStringSlice(property.Enum, property.Type) + if err != nil { + return err + } + userInput, err = cmdio.AskSelect(c.ctx, property.Description, enums) + if err != nil { + return err + } + } else { + userInput, err = cmdio.Ask(c.ctx, property.Description, defaultVal) + if err != nil { + return err + } + } // Convert user input string back to a value - c.values[name], err = fromString(userInput, property.Type) + c.values[name], err = jsonschema.FromString(userInput, property.Type) if err != nil { return err } @@ -163,42 +142,10 @@ func (c *config) promptOrAssignDefaultValues() error { // Validates the configuration. If passes, the configuration is ready to be used // to initialize the template. func (c *config) validate() error { - validateFns := []func() error{ - c.validateValuesDefined, - c.validateValuesType, - } - - for _, fn := range validateFns { - err := fn() - if err != nil { - return err - } - } - return nil -} - -// Validates all input properties have a user defined value assigned to them -func (c *config) validateValuesDefined() error { - for k := range c.schema.Properties { - if _, ok := c.values[k]; ok { - continue - } - return fmt.Errorf("no value has been assigned to input parameter %s", k) - } - return nil -} - -// Validates the types of all input properties values match their types defined in the schema -func (c *config) validateValuesType() error { - for k, v := range c.values { - fieldInfo, ok := c.schema.Properties[k] - if !ok { - return fmt.Errorf("%s is not defined as an input parameter for the template", k) - } - err := validateType(v, fieldInfo.Type) - if err != nil { - return fmt.Errorf("incorrect type for %s. %w", k, err) - } + // All properties in the JSON schema should have a value defined. + c.schema.Required = maps.Keys(c.schema.Properties) + if err := c.schema.ValidateInstance(c.values); err != nil { + return fmt.Errorf("validation for template input parameters failed. %w", err) } return nil } diff --git a/libs/template/config_test.go b/libs/template/config_test.go index 335242467..1b1fc3383 100644 --- a/libs/template/config_test.go +++ b/libs/template/config_test.go @@ -1,7 +1,7 @@ package template import ( - "encoding/json" + "context" "testing" "github.com/databricks/cli/libs/jsonschema" @@ -9,36 +9,14 @@ import ( "github.com/stretchr/testify/require" ) -func testSchema(t *testing.T) *jsonschema.Schema { - schemaJson := `{ - "properties": { - "int_val": { - "type": "integer", - "default": 123 - }, - "float_val": { - "type": "number" - }, - "bool_val": { - "type": "boolean" - }, - "string_val": { - "type": "string", - "default": "abc" - } - } - }` - var jsonSchema jsonschema.Schema - err := json.Unmarshal([]byte(schemaJson), &jsonSchema) +func testConfig(t *testing.T) *config { + c, err := newConfig(context.Background(), "./testdata/config-test-schema/test-schema.json") require.NoError(t, err) - return &jsonSchema + return c } func TestTemplateConfigAssignValuesFromFile(t *testing.T) { - c := config{ - schema: testSchema(t), - values: make(map[string]any), - } + c := testConfig(t) err := c.assignValuesFromFile("./testdata/config-assign-from-file/config.json") assert.NoError(t, err) @@ -49,32 +27,17 @@ func TestTemplateConfigAssignValuesFromFile(t *testing.T) { assert.Equal(t, "hello", c.values["string_val"]) } -func TestTemplateConfigAssignValuesFromFileForUnknownField(t *testing.T) { - c := config{ - schema: testSchema(t), - values: make(map[string]any), - } - - err := c.assignValuesFromFile("./testdata/config-assign-from-file-unknown-property/config.json") - assert.EqualError(t, err, "unknown_prop is not defined as an input parameter for the template") -} - func TestTemplateConfigAssignValuesFromFileForInvalidIntegerValue(t *testing.T) { - c := config{ - schema: testSchema(t), - values: make(map[string]any), - } + c := testConfig(t) err := c.assignValuesFromFile("./testdata/config-assign-from-file-invalid-int/config.json") - assert.EqualError(t, err, "failed to cast value abc of property int_val from file ./testdata/config-assign-from-file-invalid-int/config.json to an integer: cannot convert \"abc\" to an integer") + assert.EqualError(t, err, "failed to load config from file ./testdata/config-assign-from-file-invalid-int/config.json: failed to parse property int_val: cannot convert \"abc\" to an integer") } func TestTemplateConfigAssignValuesFromFileDoesNotOverwriteExistingConfigs(t *testing.T) { - c := config{ - schema: testSchema(t), - values: map[string]any{ - "string_val": "this-is-not-overwritten", - }, + c := testConfig(t) + c.values = map[string]any{ + "string_val": "this-is-not-overwritten", } err := c.assignValuesFromFile("./testdata/config-assign-from-file/config.json") @@ -87,10 +50,7 @@ func TestTemplateConfigAssignValuesFromFileDoesNotOverwriteExistingConfigs(t *te } func TestTemplateConfigAssignDefaultValues(t *testing.T) { - c := config{ - schema: testSchema(t), - values: make(map[string]any), - } + c := testConfig(t) err := c.assignDefaultValues() assert.NoError(t, err) @@ -101,65 +61,55 @@ func TestTemplateConfigAssignDefaultValues(t *testing.T) { } func TestTemplateConfigValidateValuesDefined(t *testing.T) { - c := config{ - schema: testSchema(t), - values: map[string]any{ - "int_val": 1, - "float_val": 1.0, - "bool_val": false, - }, + c := testConfig(t) + c.values = map[string]any{ + "int_val": 1, + "float_val": 1.0, + "bool_val": false, } - err := c.validateValuesDefined() - assert.EqualError(t, err, "no value has been assigned to input parameter string_val") + err := c.validate() + assert.EqualError(t, err, "validation for template input parameters failed. no value provided for required property string_val") } func TestTemplateConfigValidateTypeForValidConfig(t *testing.T) { - c := &config{ - schema: testSchema(t), - values: map[string]any{ - "int_val": 1, - "float_val": 1.1, - "bool_val": true, - "string_val": "abcd", - }, + c := testConfig(t) + c.values = map[string]any{ + "int_val": 1, + "float_val": 1.1, + "bool_val": true, + "string_val": "abcd", } - err := c.validateValuesType() - assert.NoError(t, err) - - err = c.validate() + err := c.validate() assert.NoError(t, err) } func TestTemplateConfigValidateTypeForUnknownField(t *testing.T) { - c := &config{ - schema: testSchema(t), - values: map[string]any{ - "unknown_prop": 1, - }, + c := testConfig(t) + c.values = map[string]any{ + "unknown_prop": 1, + "int_val": 1, + "float_val": 1.1, + "bool_val": true, + "string_val": "abcd", } - err := c.validateValuesType() - assert.EqualError(t, err, "unknown_prop is not defined as an input parameter for the template") + err := c.validate() + assert.EqualError(t, err, "validation for template input parameters failed. property unknown_prop is not defined in the schema") } func TestTemplateConfigValidateTypeForInvalidType(t *testing.T) { - c := &config{ - schema: testSchema(t), - values: map[string]any{ - "int_val": "this-should-be-an-int", - "float_val": 1.1, - "bool_val": true, - "string_val": "abcd", - }, + c := testConfig(t) + c.values = map[string]any{ + "int_val": "this-should-be-an-int", + "float_val": 1.1, + "bool_val": true, + "string_val": "abcd", } - err := c.validateValuesType() - assert.EqualError(t, err, `incorrect type for int_val. expected type integer, but value is "this-should-be-an-int"`) - - err = c.validate() - assert.EqualError(t, err, `incorrect type for int_val. expected type integer, but value is "this-should-be-an-int"`) + err := c.validate() + assert.EqualError(t, err, "validation for template input parameters failed. incorrect type for property int_val: expected type integer, but value is \"this-should-be-an-int\"") } func TestTemplateValidateSchema(t *testing.T) { @@ -192,3 +142,30 @@ func TestTemplateValidateSchema(t *testing.T) { err = validateSchema(toSchema("array")) assert.EqualError(t, err, "property type array is not supported by bundle templates") } + +func TestTemplateEnumValidation(t *testing.T) { + schema := jsonschema.Schema{ + Properties: map[string]*jsonschema.Schema{ + "abc": { + Type: "integer", + Enum: []any{1, 2, 3, 4}, + }, + }, + } + + c := &config{ + schema: &schema, + values: map[string]any{ + "abc": 5, + }, + } + assert.EqualError(t, c.validate(), "validation for template input parameters failed. expected value of property abc to be one of [1 2 3 4]. Found: 5") + + c = &config{ + schema: &schema, + values: map[string]any{ + "abc": 4, + }, + } + assert.NoError(t, c.validate()) +} diff --git a/libs/template/helpers.go b/libs/template/helpers.go index 29abbe21c..317522703 100644 --- a/libs/template/helpers.go +++ b/libs/template/helpers.go @@ -26,9 +26,10 @@ type pair struct { v any } +var cachedUser *iam.User +var cachedIsServicePrincipal *bool + func loadHelpers(ctx context.Context) template.FuncMap { - var user *iam.User - var is_service_principal *bool w := root.WorkspaceClient(ctx) return template.FuncMap{ "fail": func(format string, args ...any) (any, error) { @@ -80,32 +81,32 @@ func loadHelpers(ctx context.Context) template.FuncMap { return w.Config.Host, nil }, "user_name": func() (string, error) { - if user == nil { + if cachedUser == nil { var err error - user, err = w.CurrentUser.Me(ctx) + cachedUser, err = w.CurrentUser.Me(ctx) if err != nil { return "", err } } - result := user.UserName + result := cachedUser.UserName if result == "" { - result = user.Id + result = cachedUser.Id } return result, nil }, "is_service_principal": func() (bool, error) { - if is_service_principal != nil { - return *is_service_principal, nil + if cachedIsServicePrincipal != nil { + return *cachedIsServicePrincipal, nil } - if user == nil { + if cachedUser == nil { var err error - user, err = w.CurrentUser.Me(ctx) + cachedUser, err = w.CurrentUser.Me(ctx) if err != nil { return false, err } } - result := auth.IsServicePrincipal(user.Id) - is_service_principal = &result + result := auth.IsServicePrincipal(cachedUser.Id) + cachedIsServicePrincipal = &result return result, nil }, } diff --git a/libs/template/renderer.go b/libs/template/renderer.go index f4bd99d2c..f674ea0fb 100644 --- a/libs/template/renderer.go +++ b/libs/template/renderer.go @@ -9,6 +9,7 @@ import ( "path" "path/filepath" "slices" + "sort" "strings" "text/template" @@ -214,17 +215,22 @@ func (r *renderer) walk() error { // Add skip function, which accumulates skip patterns relative to current // directory r.baseTemplate.Funcs(template.FuncMap{ - "skip": func(relPattern string) string { + "skip": func(relPattern string) (string, error) { // patterns are specified relative to current directory of the file // the {{skip}} function is called from. - pattern := path.Join(currentDirectory, relPattern) + patternRaw := path.Join(currentDirectory, relPattern) + pattern, err := r.executeTemplate(patternRaw) + if err != nil { + return "", err + } + if !slices.Contains(r.skipPatterns, pattern) { logger.Infof(r.ctx, "adding skip pattern: %s", pattern) r.skipPatterns = append(r.skipPatterns, pattern) } // return empty string will print nothing at function call site // when executing the template - return "" + return "", nil }, }) @@ -239,6 +245,10 @@ func (r *renderer) walk() error { if err != nil { return err } + // Sort by name to ensure deterministic ordering + sort.Slice(entries, func(i, j int) bool { + return entries[i].Name() < entries[j].Name() + }) for _, entry := range entries { if entry.IsDir() { // Add to slice, for BFS traversal diff --git a/libs/template/renderer_test.go b/libs/template/renderer_test.go index a2e5675e8..21dd1e4fa 100644 --- a/libs/template/renderer_test.go +++ b/libs/template/renderer_test.go @@ -12,7 +12,14 @@ import ( "testing" "text/template" + "github.com/databricks/cli/bundle" + bundleConfig "github.com/databricks/cli/bundle/config" + "github.com/databricks/cli/bundle/config/mutator" + "github.com/databricks/cli/bundle/phases" "github.com/databricks/cli/cmd/root" + "github.com/databricks/databricks-sdk-go" + workspaceConfig "github.com/databricks/databricks-sdk-go/config" + "github.com/databricks/databricks-sdk-go/service/iam" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -29,6 +36,95 @@ func assertFilePermissions(t *testing.T, path string, perm fs.FileMode) { assert.Equal(t, perm, info.Mode().Perm()) } +func assertBuiltinTemplateValid(t *testing.T, settings map[string]any, target string, isServicePrincipal bool, build bool, tempDir string) { + ctx := context.Background() + + templatePath, err := prepareBuiltinTemplates("default-python", tempDir) + require.NoError(t, err) + + w := &databricks.WorkspaceClient{ + Config: &workspaceConfig.Config{Host: "https://myhost.com"}, + } + + // Prepare helpers + cachedUser = &iam.User{UserName: "user@domain.com"} + cachedIsServicePrincipal = &isServicePrincipal + ctx = root.SetWorkspaceClient(ctx, w) + helpers := loadHelpers(ctx) + + renderer, err := newRenderer(ctx, settings, helpers, templatePath, "./testdata/template-in-path/library", tempDir) + require.NoError(t, err) + + // Evaluate template + err = renderer.walk() + require.NoError(t, err) + err = renderer.persistToDisk() + require.NoError(t, err) + b, err := bundle.Load(ctx, filepath.Join(tempDir, "template", "my_project")) + require.NoError(t, err) + + // Apply initialize / validation mutators + b.Config.Workspace.CurrentUser = &bundleConfig.User{User: cachedUser} + b.WorkspaceClient() + b.Config.Bundle.Terraform = &bundleConfig.Terraform{ + ExecPath: "sh", + } + err = bundle.Apply(ctx, b, bundle.Seq( + bundle.Seq(mutator.DefaultMutators()...), + mutator.SelectTarget(target), + phases.Initialize(), + )) + require.NoError(t, err) + + // Apply build mutator + if build { + err = bundle.Apply(ctx, b, phases.Build()) + require.NoError(t, err) + } +} + +func TestBuiltinTemplateValid(t *testing.T) { + // Test option combinations + options := []string{"yes", "no"} + isServicePrincipal := false + build := false + for _, includeNotebook := range options { + for _, includeDlt := range options { + for _, includePython := range options { + for _, isServicePrincipal := range []bool{true, false} { + config := map[string]any{ + "project_name": "my_project", + "include_notebook": includeNotebook, + "include_dlt": includeDlt, + "include_python": includePython, + } + tempDir := t.TempDir() + assertBuiltinTemplateValid(t, config, "dev", isServicePrincipal, build, tempDir) + } + } + } + } + + // Test prod mode + build + config := map[string]any{ + "project_name": "my_project", + "include_notebook": "yes", + "include_dlt": "yes", + "include_python": "yes", + } + isServicePrincipal = false + build = true + + // On Windows, we can't always remove the resulting temp dir since background + // processes might have it open, so we use 'defer' for a best-effort cleanup + tempDir, err := os.MkdirTemp("", "templates") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + assertBuiltinTemplateValid(t, config, "prod", isServicePrincipal, build, tempDir) + defer os.RemoveAll(tempDir) +} + func TestRendererWithAssociatedTemplateInLibrary(t *testing.T) { tmpDir := t.TempDir() diff --git a/libs/template/templates/default-python/databricks_template_schema.json b/libs/template/templates/default-python/databricks_template_schema.json index 3220e9a6f..22c65f309 100644 --- a/libs/template/templates/default-python/databricks_template_schema.json +++ b/libs/template/templates/default-python/databricks_template_schema.json @@ -3,7 +3,32 @@ "project_name": { "type": "string", "default": "my_project", - "description": "Unique name for this project" + "description": "Unique name for this project", + "order": 1 + }, + "include_notebook": { + "todo": "use an enum here, see https://github.com/databricks/cli/pull/668", + "type": "string", + "default": "yes", + "pattern": "^(yes|no)$", + "description": "Include a stub (sample) notebook in 'my_project/src'", + "order": 2 + }, + "include_dlt": { + "todo": "use an enum here, see https://github.com/databricks/cli/pull/668", + "type": "string", + "default": "yes", + "pattern": "^(yes|no)$", + "description": "Include a stub (sample) DLT pipeline in 'my_project/src'", + "order": 3 + }, + "include_python": { + "todo": "use an enum here, see https://github.com/databricks/cli/pull/668", + "type": "string", + "default": "yes", + "pattern": "^(yes|no)$", + "description": "Include a stub (sample) Python package 'my_project/src'", + "order": 4 } } } diff --git a/libs/template/templates/default-python/defaults.json b/libs/template/templates/default-python/defaults.json deleted file mode 100644 index 99ecd36d2..000000000 --- a/libs/template/templates/default-python/defaults.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "project_name": "my_project" -} diff --git a/libs/template/templates/default-python/template/__preamble.tmpl b/libs/template/templates/default-python/template/__preamble.tmpl new file mode 100644 index 000000000..a86d3bffd --- /dev/null +++ b/libs/template/templates/default-python/template/__preamble.tmpl @@ -0,0 +1,38 @@ +# Preamble + +This file only template directives; it is skipped for the actual output. + +{{skip "__preamble"}} + +{{ $value := .project_name }} +{{with (regexp "^[A-Za-z0-9_]*$")}} + {{if not (.MatchString $value)}} + {{fail "Invalid project_name: %s. Must consist of letter and underscores only." $value}} + {{end}} +{{end}} + +{{$notDLT := not (eq .include_dlt "yes")}} +{{$notNotebook := not (eq .include_notebook "yes")}} +{{$notPython := not (eq .include_python "yes")}} + +{{if $notPython}} + {{skip "{{.project_name}}/src/{{.project_name}}"}} + {{skip "{{.project_name}}/tests/main_test.py"}} + {{skip "{{.project_name}}/setup.py"}} + {{skip "{{.project_name}}/pytest.ini"}} +{{end}} + +{{if $notDLT}} + {{skip "{{.project_name}}/src/dlt_pipeline.ipynb"}} + {{skip "{{.project_name}}/resources/{{.project_name}}_pipeline.yml"}} +{{end}} + +{{if $notNotebook}} + {{skip "{{.project_name}}/src/notebook.ipynb"}} +{{end}} + +{{if (and $notDLT $notNotebook $notPython)}} + {{skip "{{.project_name}}/resources/{{.project_name}}_job.yml"}} +{{else}} + {{skip "{{.project_name}}/resources/.gitkeep"}} +{{end}} diff --git a/libs/template/templates/default-python/template/{{.project_name}}/README.md.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/README.md.tmpl index 4c89435b4..1bcd7af41 100644 --- a/libs/template/templates/default-python/template/{{.project_name}}/README.md.tmpl +++ b/libs/template/templates/default-python/template/{{.project_name}}/README.md.tmpl @@ -20,7 +20,7 @@ The '{{.project_name}}' project was generated by using the default-python templa This deploys everything that's defined for this project. For example, the default template would deploy a job called - `[dev yourname] {{.project_name}}-job` to your workspace. + `[dev yourname] {{.project_name}}_job` to your workspace. You can find that job by opening your workpace and clicking on **Workflows**. 4. Similarly, to deploy a production copy, type: @@ -28,10 +28,17 @@ The '{{.project_name}}' project was generated by using the default-python templa $ databricks bundle deploy --target prod ``` -5. Optionally, install developer tools such as the Databricks extension for Visual Studio Code from - https://docs.databricks.com/dev-tools/vscode-ext.html. Or read the "getting started" documentation for - **Databricks Connect** for instructions on running the included Python code from a different IDE. +5. To run a job or pipeline, use the "run" comand: + ``` + $ databricks bundle run {{.project_name}}_job + ``` -6. For documentation on the Databricks asset bundles format used +6. Optionally, install developer tools such as the Databricks extension for Visual Studio Code from + https://docs.databricks.com/dev-tools/vscode-ext.html. +{{- if (eq .include_python "yes") }} Or read the "getting started" documentation for + **Databricks Connect** for instructions on running the included Python code from a different IDE. +{{- end}} + +7. For documentation on the Databricks asset bundles format used for this project, and for CI/CD configuration, see https://docs.databricks.com/dev-tools/bundles/index.html. diff --git a/libs/template/templates/default-python/template/{{.project_name}}/databricks.yml.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/databricks.yml.tmpl index 48aef0ea3..7fbf4da4c 100644 --- a/libs/template/templates/default-python/template/{{.project_name}}/databricks.yml.tmpl +++ b/libs/template/templates/default-python/template/{{.project_name}}/databricks.yml.tmpl @@ -7,10 +7,10 @@ include: - resources/*.yml targets: - # The 'dev' target, used development purposes. + # The 'dev' target, used for development purposes. # Whenever a developer deploys using 'dev', they get their own copy. dev: - # We use 'mode: development' to make everything deployed to this target gets a prefix + # We use 'mode: development' to make sure everything deployed to this target gets a prefix # like '[dev my_user_name]'. Setting this mode also disables any schedules and # automatic triggers for jobs and enables the 'development' mode for Delta Live Tables pipelines. mode: development diff --git a/libs/template/templates/default-python/template/{{.project_name}}/fixtures/.gitkeep.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/fixtures/.gitkeep.tmpl index 361c681f9..ee9570302 100644 --- a/libs/template/templates/default-python/template/{{.project_name}}/fixtures/.gitkeep.tmpl +++ b/libs/template/templates/default-python/template/{{.project_name}}/fixtures/.gitkeep.tmpl @@ -17,7 +17,7 @@ def get_absolute_path(*relative_parts): if 'dbutils' in globals(): base_dir = os.path.dirname(dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()) # type: ignore path = os.path.normpath(os.path.join(base_dir, *relative_parts)) - return path if path.startswith("/Workspace") else os.path.join("/Workspace", path) + return path if path.startswith("/Workspace") else "/Workspace" + path else: return os.path.join(*relative_parts) diff --git a/libs/template/templates/default-python/template/{{.project_name}}/resources/.gitkeep b/libs/template/templates/default-python/template/{{.project_name}}/resources/.gitkeep new file mode 100644 index 000000000..3e09c14c1 --- /dev/null +++ b/libs/template/templates/default-python/template/{{.project_name}}/resources/.gitkeep @@ -0,0 +1 @@ +This folder is reserved for Databricks Asset Bundles resource definitions. diff --git a/libs/template/templates/default-python/template/{{.project_name}}/resources/{{.project_name}}_job.yml.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/resources/{{.project_name}}_job.yml.tmpl index f8116cdfc..1792f9479 100644 --- a/libs/template/templates/default-python/template/{{.project_name}}/resources/{{.project_name}}_job.yml.tmpl +++ b/libs/template/templates/default-python/template/{{.project_name}}/resources/{{.project_name}}_job.yml.tmpl @@ -1,6 +1,5 @@ # The main job for {{.project_name}} resources: - jobs: {{.project_name}}_job: name: {{.project_name}}_job @@ -10,20 +9,41 @@ resources: timezone_id: Europe/Amsterdam {{- if not is_service_principal}} + email_notifications: on_failure: - {{user_name}} + + {{else}} + {{end -}} tasks: + {{- if eq .include_notebook "yes" }} - task_key: notebook_task job_cluster_key: job_cluster notebook_task: notebook_path: ../src/notebook.ipynb - - - task_key: python_wheel_task + {{end -}} + {{- if (eq .include_dlt "yes") }} + - task_key: refresh_pipeline + {{- if (eq .include_notebook "yes" )}} depends_on: - task_key: notebook_task + {{- end}} + pipeline_task: + {{- /* TODO: we should find a way that doesn't use magics for the below, like ./{{project_name}}_pipeline.yml */}} + pipeline_id: ${resources.pipelines.{{.project_name}}_pipeline.id} + {{end -}} + {{- if (eq .include_python "yes") }} + - task_key: main_task + {{- if (eq .include_dlt "yes") }} + depends_on: + - task_key: refresh_pipeline + {{- else if (eq .include_notebook "yes" )}} + depends_on: + - task_key: notebook_task + {{end}} job_cluster_key: job_cluster python_wheel_task: package_name: {{.project_name}} @@ -31,6 +51,8 @@ resources: libraries: - whl: ../dist/*.whl + {{else}} + {{end -}} job_clusters: - job_cluster_key: job_cluster new_cluster: diff --git a/libs/template/templates/default-python/template/{{.project_name}}/resources/{{.project_name}}_pipeline.yml.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/resources/{{.project_name}}_pipeline.yml.tmpl new file mode 100644 index 000000000..4b8f74d17 --- /dev/null +++ b/libs/template/templates/default-python/template/{{.project_name}}/resources/{{.project_name}}_pipeline.yml.tmpl @@ -0,0 +1,12 @@ +# The main pipeline for {{.project_name}} +resources: + pipelines: + {{.project_name}}_pipeline: + name: {{.project_name}}_pipeline + target: {{.project_name}}_${bundle.environment} + libraries: + - notebook: + path: ../src/dlt_pipeline.ipynb + + configuration: + bundle.sourcePath: /Workspace/${workspace.file_path}/src diff --git a/libs/template/templates/default-python/template/{{.project_name}}/scratch/exploration.ipynb b/libs/template/templates/default-python/template/{{.project_name}}/scratch/exploration.ipynb.tmpl similarity index 84% rename from libs/template/templates/default-python/template/{{.project_name}}/scratch/exploration.ipynb rename to libs/template/templates/default-python/template/{{.project_name}}/scratch/exploration.ipynb.tmpl index 2ee36c3c1..04bb261cd 100644 --- a/libs/template/templates/default-python/template/{{.project_name}}/scratch/exploration.ipynb +++ b/libs/template/templates/default-python/template/{{.project_name}}/scratch/exploration.ipynb.tmpl @@ -17,11 +17,15 @@ }, "outputs": [], "source": [ + {{- if (eq .include_python "yes") }} "import sys\n", "sys.path.append('../src')\n", - "from project import main\n", + "from {{.project_name}} import main\n", "\n", - "main.taxis.show(10)" + "main.get_taxis().show(10)" + {{else}} + "spark.range(10)" + {{end -}} ] } ], diff --git a/libs/template/templates/default-python/template/{{.project_name}}/setup.py.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/setup.py.tmpl index 93f4e9ff9..efd598820 100644 --- a/libs/template/templates/default-python/template/{{.project_name}}/setup.py.tmpl +++ b/libs/template/templates/default-python/template/{{.project_name}}/setup.py.tmpl @@ -15,7 +15,7 @@ setup( name="{{.project_name}}", version={{.project_name}}.__version__, url="https://databricks.com", - author="{{.user_name}}", + author="{{user_name}}", description="my test wheel", packages=find_packages(where='./src'), package_dir={'': 'src'}, diff --git a/libs/template/templates/default-python/template/{{.project_name}}/src/dlt_pipeline.ipynb.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/src/dlt_pipeline.ipynb.tmpl new file mode 100644 index 000000000..4f50294f6 --- /dev/null +++ b/libs/template/templates/default-python/template/{{.project_name}}/src/dlt_pipeline.ipynb.tmpl @@ -0,0 +1,102 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "9a626959-61c8-4bba-84d2-2a4ecab1f7ec", + "showTitle": false, + "title": "" + } + }, + "source": [ + "# DLT pipeline\n", + "\n", + "This Delta Live Tables (DLT) definition is executed using a pipeline defined in resources/{{.project_name}}_pipeline.yml." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "9198e987-5606-403d-9f6d-8f14e6a4017f", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + {{- if (eq .include_python "yes") }} + "# Import DLT and src/{{.project_name}}\n", + "import dlt\n", + "import sys\n", + "sys.path.append(spark.conf.get(\"bundle.sourcePath\", \".\"))\n", + "from pyspark.sql.functions import expr\n", + "from {{.project_name}} import main" + {{else}} + "import dlt\n", + "from pyspark.sql.functions import expr\n", + "from pyspark.sql import SparkSession\n", + "spark = SparkSession.builder.getOrCreate()" + {{end -}} + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "3fc19dba-61fd-4a89-8f8c-24fee63bfb14", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + {{- if (eq .include_python "yes") }} + "@dlt.view\n", + "def taxi_raw():\n", + " return main.get_taxis()\n", + {{else}} + "\n", + "@dlt.view\n", + "def taxi_raw():\n", + " return spark.read.format(\"json\").load(\"/databricks-datasets/nyctaxi/sample/json/\")\n", + {{end -}} + "\n", + "@dlt.table\n", + "def filtered_taxis():\n", + " return dlt.read(\"taxi_raw\").filter(expr(\"fare_amount < 30\"))" + ] + } + ], + "metadata": { + "application/vnd.databricks.v1+notebook": { + "dashboards": [], + "language": "python", + "notebookMetadata": { + "pythonIndentUnit": 2 + }, + "notebookName": "dlt_pipeline", + "widgets": {} + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/libs/template/templates/default-python/template/{{.project_name}}/src/notebook.ipynb.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/src/notebook.ipynb.tmpl index 26c743032..0ab61db2c 100644 --- a/libs/template/templates/default-python/template/{{.project_name}}/src/notebook.ipynb.tmpl +++ b/libs/template/templates/default-python/template/{{.project_name}}/src/notebook.ipynb.tmpl @@ -14,7 +14,7 @@ "source": [ "# Default notebook\n", "\n", - "This default notebook is executed using Databricks Workflows as defined in resources/{{.my_project}}_job.yml." + "This default notebook is executed using Databricks Workflows as defined in resources/{{.project_name}}_job.yml." ] }, { @@ -34,9 +34,13 @@ }, "outputs": [], "source": [ + {{- if (eq .include_python "yes") }} "from {{.project_name}} import main\n", "\n", - "main.get_taxis().show(10)\n" + "main.get_taxis().show(10)" + {{else}} + "spark.range(10)" + {{end -}} ] } ], diff --git a/libs/template/templates/default-python/template/{{.project_name}}/tests/main_test.py.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/tests/main_test.py.tmpl index 92afccc6c..f1750046a 100644 --- a/libs/template/templates/default-python/template/{{.project_name}}/tests/main_test.py.tmpl +++ b/libs/template/templates/default-python/template/{{.project_name}}/tests/main_test.py.tmpl @@ -2,4 +2,4 @@ from {{.project_name}} import main def test_main(): taxis = main.get_taxis() - assert taxis.count() == 5 + assert taxis.count() > 5 diff --git a/libs/template/testdata/config-test-schema/test-schema.json b/libs/template/testdata/config-test-schema/test-schema.json new file mode 100644 index 000000000..41eb82519 --- /dev/null +++ b/libs/template/testdata/config-test-schema/test-schema.json @@ -0,0 +1,18 @@ +{ + "properties": { + "int_val": { + "type": "integer", + "default": 123 + }, + "float_val": { + "type": "number" + }, + "bool_val": { + "type": "boolean" + }, + "string_val": { + "type": "string", + "default": "abc" + } + } +} diff --git a/main.go b/main.go index a4b8aabd6..8c8516d9d 100644 --- a/main.go +++ b/main.go @@ -1,10 +1,12 @@ package main import ( + "context" + "github.com/databricks/cli/cmd" "github.com/databricks/cli/cmd/root" ) func main() { - root.Execute(cmd.New()) + root.Execute(cmd.New(context.Background())) } diff --git a/main_test.go b/main_test.go index 6a5d19448..34ecdca0f 100644 --- a/main_test.go +++ b/main_test.go @@ -1,6 +1,7 @@ package main import ( + "context" "testing" "github.com/databricks/cli/cmd" @@ -15,7 +16,7 @@ func TestCommandsDontUseUnderscoreInName(t *testing.T) { // This test lives in the main package because this is where // all commands are imported. // - queue := []*cobra.Command{cmd.New()} + queue := []*cobra.Command{cmd.New(context.Background())} for len(queue) > 0 { cmd := queue[0] assert.NotContains(t, cmd.Name(), "_") diff --git a/python/utils.go b/python/utils.go index a8408fae2..47d5462d2 100644 --- a/python/utils.go +++ b/python/utils.go @@ -30,6 +30,8 @@ func FindFilesWithSuffixInPath(dir, suffix string) []string { log.Debugf(context.Background(), "open dir %s: %s", dir, err) return nil } + defer f.Close() + entries, err := f.ReadDir(0) if err != nil { log.Debugf(context.Background(), "read dir %s: %s", dir, err)