Merge branch 'main' of github.com:databricks/cli into mutator-with-wrappers

kartikgupta-db 2023-09-13 12:59:19 +02:00
commit f45935f8ae
No known key found for this signature in database
GPG Key ID: 6AD5FA11FACDEA39
138 changed files with 2976 additions and 563 deletions


@@ -1,5 +1,142 @@
# Version changelog
## 0.205.0
This release marks the public preview phase of Databricks Asset Bundles.
For more information, please refer to our online documentation at
https://docs.databricks.com/en/dev-tools/bundles/.
CLI:
* Prompt once for a client profile ([#727](https://github.com/databricks/cli/pull/727)).
Bundles:
* Use clearer error message when no interpolation value is found. ([#764](https://github.com/databricks/cli/pull/764)).
* Use interactive prompt to select resource to run if not specified ([#762](https://github.com/databricks/cli/pull/762)).
* Add documentation link to bundle command group description ([#770](https://github.com/databricks/cli/pull/770)).
## 0.204.1
Bundles:
* Fix conversion of job parameters ([#744](https://github.com/databricks/cli/pull/744)).
* Add schema and config validation to jsonschema package ([#740](https://github.com/databricks/cli/pull/740)).
* Support Model Serving Endpoints in bundles ([#682](https://github.com/databricks/cli/pull/682)).
* Do not include empty output in job run output ([#749](https://github.com/databricks/cli/pull/749)).
* Fixed marking libraries from DBFS as remote ([#750](https://github.com/databricks/cli/pull/750)).
* Process only Python wheel tasks which have local libraries used ([#751](https://github.com/databricks/cli/pull/751)).
* Add enum support for bundle templates ([#668](https://github.com/databricks/cli/pull/668)).
* Apply Python wheel trampoline if workspace library is used ([#755](https://github.com/databricks/cli/pull/755)).
* List available targets when incorrect target passed ([#756](https://github.com/databricks/cli/pull/756)).
* Make bundle and sync fields optional ([#757](https://github.com/databricks/cli/pull/757)).
* Consolidate environment variable interaction ([#747](https://github.com/databricks/cli/pull/747)).
Internal:
* Update Go SDK to v0.19.1 ([#759](https://github.com/databricks/cli/pull/759)).
## 0.204.0
This release includes permission-related commands for the subset of workspace
services where they apply. These complement the `permissions` command and
do not require specifying the object type to work with, as that is
implied by the command they are nested under.
CLI:
* Group permission related commands ([#730](https://github.com/databricks/cli/pull/730)).
Bundles:
* Fixed artifact file uploading on Windows and wheel execution on DBR 13.3 ([#722](https://github.com/databricks/cli/pull/722)).
* Make resource and artifact paths in bundle config relative to config folder ([#708](https://github.com/databricks/cli/pull/708)).
* Add support for ordering of input prompts ([#662](https://github.com/databricks/cli/pull/662)).
* Fix IsServicePrincipal() only working for workspace admins ([#732](https://github.com/databricks/cli/pull/732)).
* databricks bundle init template v1 ([#686](https://github.com/databricks/cli/pull/686)).
* databricks bundle init template v2: optional stubs, DLT support ([#700](https://github.com/databricks/cli/pull/700)).
* Show 'databricks bundle init' template in CLI prompt ([#725](https://github.com/databricks/cli/pull/725)).
* Include in set of environment variables to pass along. ([#736](https://github.com/databricks/cli/pull/736)).
Internal:
* Update Go SDK to v0.19.0 ([#729](https://github.com/databricks/cli/pull/729)).
* Replace API call to test configuration with dummy authenticate call ([#728](https://github.com/databricks/cli/pull/728)).
API Changes:
* Changed `databricks account storage-credentials create` command to return .
* Changed `databricks account storage-credentials get` command to return .
* Changed `databricks account storage-credentials list` command to return .
* Changed `databricks account storage-credentials update` command to return .
* Changed `databricks connections create` command with new required argument order.
* Changed `databricks connections update` command with new required argument order.
* Changed `databricks volumes create` command with new required argument order.
* Added `databricks artifact-allowlists` command group.
* Added `databricks model-versions` command group.
* Added `databricks registered-models` command group.
* Added `databricks cluster-policies get-permission-levels` command.
* Added `databricks cluster-policies get-permissions` command.
* Added `databricks cluster-policies set-permissions` command.
* Added `databricks cluster-policies update-permissions` command.
* Added `databricks clusters get-permission-levels` command.
* Added `databricks clusters get-permissions` command.
* Added `databricks clusters set-permissions` command.
* Added `databricks clusters update-permissions` command.
* Added `databricks instance-pools get-permission-levels` command.
* Added `databricks instance-pools get-permissions` command.
* Added `databricks instance-pools set-permissions` command.
* Added `databricks instance-pools update-permissions` command.
* Added `databricks files` command group.
* Changed `databricks permissions set` command to start returning .
* Changed `databricks permissions update` command to start returning .
* Added `databricks users get-permission-levels` command.
* Added `databricks users get-permissions` command.
* Added `databricks users set-permissions` command.
* Added `databricks users update-permissions` command.
* Added `databricks jobs get-permission-levels` command.
* Added `databricks jobs get-permissions` command.
* Added `databricks jobs set-permissions` command.
* Added `databricks jobs update-permissions` command.
* Changed `databricks experiments get-by-name` command to return .
* Changed `databricks experiments get-experiment` command to return .
* Added `databricks experiments delete-runs` command.
* Added `databricks experiments get-permission-levels` command.
* Added `databricks experiments get-permissions` command.
* Added `databricks experiments restore-runs` command.
* Added `databricks experiments set-permissions` command.
* Added `databricks experiments update-permissions` command.
* Added `databricks model-registry get-permission-levels` command.
* Added `databricks model-registry get-permissions` command.
* Added `databricks model-registry set-permissions` command.
* Added `databricks model-registry update-permissions` command.
* Added `databricks pipelines get-permission-levels` command.
* Added `databricks pipelines get-permissions` command.
* Added `databricks pipelines set-permissions` command.
* Added `databricks pipelines update-permissions` command.
* Added `databricks serving-endpoints get-permission-levels` command.
* Added `databricks serving-endpoints get-permissions` command.
* Added `databricks serving-endpoints set-permissions` command.
* Added `databricks serving-endpoints update-permissions` command.
* Added `databricks token-management get-permission-levels` command.
* Added `databricks token-management get-permissions` command.
* Added `databricks token-management set-permissions` command.
* Added `databricks token-management update-permissions` command.
* Changed `databricks dashboards create` command with new required argument order.
* Added `databricks warehouses get-permission-levels` command.
* Added `databricks warehouses get-permissions` command.
* Added `databricks warehouses set-permissions` command.
* Added `databricks warehouses update-permissions` command.
* Added `databricks dashboard-widgets` command group.
* Added `databricks query-visualizations` command group.
* Added `databricks repos get-permission-levels` command.
* Added `databricks repos get-permissions` command.
* Added `databricks repos set-permissions` command.
* Added `databricks repos update-permissions` command.
* Added `databricks secrets get-secret` command.
* Added `databricks workspace get-permission-levels` command.
* Added `databricks workspace get-permissions` command.
* Added `databricks workspace set-permissions` command.
* Added `databricks workspace update-permissions` command.
OpenAPI commit 09a7fa63d9ae243e5407941f200960ca14d48b07 (2023-09-04)
## 0.203.3
Bundles:


@@ -105,6 +105,7 @@ func TestUploadArtifactFileToCorrectRemotePath(t *testing.T) {
 	b.WorkspaceClient().Workspace.WithImpl(MockWorkspaceService{})
 	artifact := &config.Artifact{
+		Type: "whl",
 		Files: []config.ArtifactFile{
 			{
 				Source: whlPath,
@@ -118,4 +119,5 @@ func TestUploadArtifactFileToCorrectRemotePath(t *testing.T) {
 	err := uploadArtifact(context.Background(), artifact, b)
 	require.NoError(t, err)
 	require.Regexp(t, regexp.MustCompile("/Users/test@databricks.com/whatever/.internal/[a-z0-9]+/test.whl"), artifact.Files[0].RemotePath)
+	require.Regexp(t, regexp.MustCompile("/Workspace/Users/test@databricks.com/whatever/.internal/[a-z0-9]+/test.whl"), artifact.Files[0].Libraries[0].Whl)
 }


@@ -27,9 +27,9 @@ func (m *detectPkg) Name() string {
 }
 
 func (m *detectPkg) Apply(ctx context.Context, b *bundle.Bundle) error {
-	wheelTasks := libraries.FindAllWheelTasks(b)
+	wheelTasks := libraries.FindAllWheelTasksWithLocalLibraries(b)
 	if len(wheelTasks) == 0 {
-		log.Infof(ctx, "No wheel tasks in databricks.yml config, skipping auto detect")
+		log.Infof(ctx, "No local wheel tasks in databricks.yml config, skipping auto detect")
 		return nil
 	}
 	cmdio.LogString(ctx, "artifacts.whl.AutoDetect: Detecting Python wheel project...")


@@ -26,7 +26,7 @@ func (*fromLibraries) Apply(ctx context.Context, b *bundle.Bundle) error {
 		return nil
 	}
 
-	tasks := libraries.FindAllWheelTasks(b)
+	tasks := libraries.FindAllWheelTasksWithLocalLibraries(b)
 	for _, task := range tasks {
 		for _, lib := range task.Libraries {
 			matches, err := filepath.Glob(filepath.Join(b.Config.Path, lib.Whl))
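Both call sites above now select only wheel tasks whose libraries are local, which ties into the earlier fix for wrongly marking DBFS libraries as remote. A hedged sketch of the kind of locality check this implies; the prefix list here is an assumption for illustration, not the CLI's actual rule:

```go
package main

import (
	"fmt"
	"strings"
)

// isLocalLibrary reports whether a wheel path points into the bundle
// tree rather than an already-remote location. The set of remote
// prefixes is illustrative only.
func isLocalLibrary(path string) bool {
	for _, prefix := range []string{"dbfs:/", "/dbfs/", "/Workspace/"} {
		if strings.HasPrefix(path, prefix) {
			return false
		}
	}
	return true
}

func main() {
	fmt.Println(isLocalLibrary("./dist/test.whl"))     // true
	fmt.Println(isLocalLibrary("dbfs:/libs/test.whl")) // false
}
```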


@@ -14,6 +14,7 @@ import (
 	"sync"
 
 	"github.com/databricks/cli/bundle/config"
+	"github.com/databricks/cli/bundle/env"
 	"github.com/databricks/cli/folders"
 	"github.com/databricks/cli/libs/git"
 	"github.com/databricks/cli/libs/locker"
@@ -37,6 +38,10 @@ type Bundle struct {
 	// Stores an initialized copy of this bundle's Terraform wrapper.
 	Terraform *tfexec.Terraform
 
+	// Indicates that the Terraform definition based on this bundle is empty,
+	// i.e. that it would deploy no resources.
+	TerraformHasNoResources bool
+
 	// Stores the locker responsible for acquiring/releasing a deployment lock.
 	Locker *locker.Locker
@@ -47,8 +52,6 @@ type Bundle struct {
 	AutoApprove bool
 }
 
-const ExtraIncludePathsKey string = "DATABRICKS_BUNDLE_INCLUDES"
-
 func Load(ctx context.Context, path string) (*Bundle, error) {
 	bundle := &Bundle{}
 	stat, err := os.Stat(path)
@@ -57,9 +60,9 @@ func Load(ctx context.Context, path string) (*Bundle, error) {
 	}
 	configFile, err := config.FileNames.FindInPath(path)
 	if err != nil {
-		_, hasIncludePathEnv := os.LookupEnv(ExtraIncludePathsKey)
-		_, hasBundleRootEnv := os.LookupEnv(envBundleRoot)
-		if hasIncludePathEnv && hasBundleRootEnv && stat.IsDir() {
+		_, hasRootEnv := env.Root(ctx)
+		_, hasIncludesEnv := env.Includes(ctx)
+		if hasRootEnv && hasIncludesEnv && stat.IsDir() {
 			log.Debugf(ctx, "No bundle configuration; using bundle root: %s", path)
 			bundle.Config = config.Root{
 				Path: path,
@@ -82,7 +85,7 @@ func Load(ctx context.Context, path string) (*Bundle, error) {
 // MustLoad returns a bundle configuration.
 // It returns an error if a bundle was not found or could not be loaded.
 func MustLoad(ctx context.Context) (*Bundle, error) {
-	root, err := mustGetRoot()
+	root, err := mustGetRoot(ctx)
 	if err != nil {
 		return nil, err
 	}
@@ -94,7 +97,7 @@ func MustLoad(ctx context.Context) (*Bundle, error) {
 // It returns an error if a bundle was found but could not be loaded.
 // It returns a `nil` bundle if a bundle was not found.
 func TryLoad(ctx context.Context) (*Bundle, error) {
-	root, err := tryGetRoot()
+	root, err := tryGetRoot(ctx)
 	if err != nil {
 		return nil, err
 	}
@@ -120,13 +123,12 @@ func (b *Bundle) WorkspaceClient() *databricks.WorkspaceClient {
 // CacheDir returns directory to use for temporary files for this bundle.
 // Scoped to the bundle's target.
-func (b *Bundle) CacheDir(paths ...string) (string, error) {
+func (b *Bundle) CacheDir(ctx context.Context, paths ...string) (string, error) {
 	if b.Config.Bundle.Target == "" {
 		panic("target not set")
 	}
-	cacheDirName, exists := os.LookupEnv("DATABRICKS_BUNDLE_TMP")
+	cacheDirName, exists := env.TempDir(ctx)
 	if !exists || cacheDirName == "" {
 		cacheDirName = filepath.Join(
 			// Anchor at bundle root directory.
@@ -159,8 +161,8 @@ func (b *Bundle) CacheDir(paths ...string) (string, error) {
 // This directory is used to store and automaticaly sync internal bundle files, such as, f.e
 // notebook trampoline files for Python wheel and etc.
-func (b *Bundle) InternalDir() (string, error) {
-	cacheDir, err := b.CacheDir()
+func (b *Bundle) InternalDir(ctx context.Context) (string, error) {
+	cacheDir, err := b.CacheDir(ctx)
 	if err != nil {
 		return "", err
 	}
@@ -177,8 +179,8 @@ func (b *Bundle) InternalDir() (string, error) {
 // GetSyncIncludePatterns returns a list of user defined includes
 // And also adds InternalDir folder to include list for sync command
 // so this folder is always synced
-func (b *Bundle) GetSyncIncludePatterns() ([]string, error) {
-	internalDir, err := b.InternalDir()
+func (b *Bundle) GetSyncIncludePatterns(ctx context.Context) ([]string, error) {
+	internalDir, err := b.InternalDir(ctx)
 	if err != nil {
 		return nil, err
 	}


@@ -6,6 +6,7 @@ import (
 	"path/filepath"
 	"testing"
 
+	"github.com/databricks/cli/bundle/env"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
@@ -23,12 +24,13 @@ func TestLoadExists(t *testing.T) {
 }
 
 func TestBundleCacheDir(t *testing.T) {
+	ctx := context.Background()
 	projectDir := t.TempDir()
 	f1, err := os.Create(filepath.Join(projectDir, "databricks.yml"))
 	require.NoError(t, err)
 	f1.Close()
 
-	bundle, err := Load(context.Background(), projectDir)
+	bundle, err := Load(ctx, projectDir)
 	require.NoError(t, err)
 
 	// Artificially set target.
@@ -38,7 +40,7 @@ func TestBundleCacheDir(t *testing.T) {
 	// unset env variable in case it's set
 	t.Setenv("DATABRICKS_BUNDLE_TMP", "")
 
-	cacheDir, err := bundle.CacheDir()
+	cacheDir, err := bundle.CacheDir(ctx)
 
 	// format is <CWD>/.databricks/bundle/<target>
 	assert.NoError(t, err)
@@ -46,13 +48,14 @@ func TestBundleCacheDir(t *testing.T) {
 }
 
 func TestBundleCacheDirOverride(t *testing.T) {
+	ctx := context.Background()
 	projectDir := t.TempDir()
 	bundleTmpDir := t.TempDir()
 	f1, err := os.Create(filepath.Join(projectDir, "databricks.yml"))
 	require.NoError(t, err)
 	f1.Close()
 
-	bundle, err := Load(context.Background(), projectDir)
+	bundle, err := Load(ctx, projectDir)
 	require.NoError(t, err)
 
 	// Artificially set target.
@@ -62,7 +65,7 @@ func TestBundleCacheDirOverride(t *testing.T) {
 	// now we expect to use 'bundleTmpDir' instead of CWD/.databricks/bundle
 	t.Setenv("DATABRICKS_BUNDLE_TMP", bundleTmpDir)
 
-	cacheDir, err := bundle.CacheDir()
+	cacheDir, err := bundle.CacheDir(ctx)
 
 	// format is <DATABRICKS_BUNDLE_TMP>/<target>
 	assert.NoError(t, err)
@@ -70,14 +73,14 @@ func TestBundleCacheDirOverride(t *testing.T) {
 }
 
 func TestBundleMustLoadSuccess(t *testing.T) {
-	t.Setenv(envBundleRoot, "./tests/basic")
+	t.Setenv(env.RootVariable, "./tests/basic")
 	b, err := MustLoad(context.Background())
 	require.NoError(t, err)
 	assert.Equal(t, "tests/basic", filepath.ToSlash(b.Config.Path))
 }
 
 func TestBundleMustLoadFailureWithEnv(t *testing.T) {
-	t.Setenv(envBundleRoot, "./tests/doesntexist")
+	t.Setenv(env.RootVariable, "./tests/doesntexist")
 	_, err := MustLoad(context.Background())
 	require.Error(t, err, "not a directory")
 }
@@ -89,14 +92,14 @@ func TestBundleMustLoadFailureIfNotFound(t *testing.T) {
 }
 
 func TestBundleTryLoadSuccess(t *testing.T) {
-	t.Setenv(envBundleRoot, "./tests/basic")
+	t.Setenv(env.RootVariable, "./tests/basic")
 	b, err := TryLoad(context.Background())
 	require.NoError(t, err)
 	assert.Equal(t, "tests/basic", filepath.ToSlash(b.Config.Path))
 }
 
 func TestBundleTryLoadFailureWithEnv(t *testing.T) {
-	t.Setenv(envBundleRoot, "./tests/doesntexist")
+	t.Setenv(env.RootVariable, "./tests/doesntexist")
 	_, err := TryLoad(context.Background())
 	require.Error(t, err, "not a directory")
 }


@@ -78,9 +78,13 @@ func (a *Artifact) NormalisePaths() {
 		remotePath := path.Join(wsfsBase, f.RemotePath)
 		for i := range f.Libraries {
 			lib := f.Libraries[i]
-			switch a.Type {
-			case ArtifactPythonWheel:
+			if lib.Whl != "" {
 				lib.Whl = remotePath
+				continue
+			}
+			if lib.Jar != "" {
+				lib.Jar = remotePath
+				continue
 			}
 		}
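The refactor above stops dispatching on the artifact type and instead inspects which library field is populated, so both wheels and JARs get pointed at their uploaded location. A standalone sketch of that if-chain (the `Library` type and helper name are illustrative, not the CLI's):

```go
package main

import (
	"fmt"
	"path"
)

// Library mirrors the whl/jar fields the hunk above rewrites.
type Library struct {
	Whl string
	Jar string
}

// normalise points whichever field is set at the uploaded remote path,
// matching the field-based check that replaced the switch on artifact type.
func normalise(lib *Library, remotePath string) {
	if lib.Whl != "" {
		lib.Whl = remotePath
		return
	}
	if lib.Jar != "" {
		lib.Jar = remotePath
	}
}

func main() {
	lib := Library{Whl: "dist/test.whl"}
	normalise(&lib, path.Join("/Workspace/Users/someone/.internal", "test.whl"))
	fmt.Println(lib.Whl) // /Workspace/Users/someone/.internal/test.whl
}
```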


@@ -184,7 +184,7 @@ func (a *accumulator) Resolve(path string, seenPaths []string, fns ...LookupFunc
 	// fetch the string node to resolve
 	field, ok := a.strings[path]
 	if !ok {
-		return fmt.Errorf("could not resolve reference %s", path)
+		return fmt.Errorf("no value found for interpolation reference: ${%s}", path)
 	}
 
 	// return early if the string field has no variables to interpolate


@@ -247,5 +247,5 @@ func TestInterpolationInvalidVariableReference(t *testing.T) {
 	}
 
 	err := expand(&config)
-	assert.ErrorContains(t, err, "could not resolve reference vars.foo")
+	assert.ErrorContains(t, err, "no value found for interpolation reference: ${vars.foo}")
 }


@@ -3,11 +3,11 @@ package mutator
 import (
 	"context"
 	"fmt"
-	"os"
 
 	"github.com/databricks/cli/bundle"
 	"github.com/databricks/cli/bundle/config"
 	"github.com/databricks/cli/bundle/config/resources"
+	"github.com/databricks/cli/libs/env"
 )
 
 type overrideCompute struct{}
@@ -39,8 +39,8 @@ func (m *overrideCompute) Apply(ctx context.Context, b *bundle.Bundle) error {
 		}
 		return nil
 	}
-	if os.Getenv("DATABRICKS_CLUSTER_ID") != "" {
-		b.Config.Bundle.ComputeID = os.Getenv("DATABRICKS_CLUSTER_ID")
+	if v := env.Get(ctx, "DATABRICKS_CLUSTER_ID"); v != "" {
+		b.Config.Bundle.ComputeID = v
 	}
 
 	if b.Config.Bundle.ComputeID == "" {


@@ -21,6 +21,10 @@ func (m *populateCurrentUser) Name() string {
 }
 
 func (m *populateCurrentUser) Apply(ctx context.Context, b *bundle.Bundle) error {
+	if b.Config.Workspace.CurrentUser != nil {
+		return nil
+	}
+
 	w := b.WorkspaceClient()
 	me, err := w.CurrentUser.Me(ctx)
 	if err != nil {


@@ -10,11 +10,12 @@ import (
 	"github.com/databricks/cli/bundle"
 	"github.com/databricks/cli/bundle/config"
+	"github.com/databricks/cli/bundle/env"
 )
 
 // Get extra include paths from environment variable
-func GetExtraIncludePaths() []string {
-	value, exists := os.LookupEnv(bundle.ExtraIncludePathsKey)
+func getExtraIncludePaths(ctx context.Context) []string {
+	value, exists := env.Includes(ctx)
 	if !exists {
 		return nil
 	}
@@ -48,7 +49,7 @@ func (m *processRootIncludes) Apply(ctx context.Context, b *bundle.Bundle) error
 	var files []string
 
 	// Converts extra include paths from environment variable to relative paths
-	for _, extraIncludePath := range GetExtraIncludePaths() {
+	for _, extraIncludePath := range getExtraIncludePaths(ctx) {
 		if filepath.IsAbs(extraIncludePath) {
 			rel, err := filepath.Rel(b.Config.Path, extraIncludePath)
 			if err != nil {


@@ -2,16 +2,17 @@ package mutator_test
 
 import (
 	"context"
-	"fmt"
 	"os"
 	"path"
 	"path/filepath"
 	"runtime"
+	"strings"
 	"testing"
 
 	"github.com/databricks/cli/bundle"
 	"github.com/databricks/cli/bundle/config"
 	"github.com/databricks/cli/bundle/config/mutator"
+	"github.com/databricks/cli/bundle/env"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
@@ -129,10 +130,7 @@ func TestProcessRootIncludesExtrasFromEnvVar(t *testing.T) {
 	rootPath := t.TempDir()
 	testYamlName := "extra_include_path.yml"
 	touch(t, rootPath, testYamlName)
-	os.Setenv(bundle.ExtraIncludePathsKey, path.Join(rootPath, testYamlName))
-	t.Cleanup(func() {
-		os.Unsetenv(bundle.ExtraIncludePathsKey)
-	})
+	t.Setenv(env.IncludesVariable, path.Join(rootPath, testYamlName))
 
 	bundle := &bundle.Bundle{
 		Config: config.Root{
@@ -149,7 +147,13 @@ func TestProcessRootIncludesDedupExtrasFromEnvVar(t *testing.T) {
 	rootPath := t.TempDir()
 	testYamlName := "extra_include_path.yml"
 	touch(t, rootPath, testYamlName)
-	t.Setenv(bundle.ExtraIncludePathsKey, fmt.Sprintf("%s%s%s", path.Join(rootPath, testYamlName), string(os.PathListSeparator), path.Join(rootPath, testYamlName)))
+	t.Setenv(env.IncludesVariable, strings.Join(
+		[]string{
+			path.Join(rootPath, testYamlName),
+			path.Join(rootPath, testYamlName),
+		},
+		string(os.PathListSeparator),
+	))
 
 	bundle := &bundle.Bundle{
 		Config: config.Root{


@@ -77,6 +77,12 @@ func transformDevelopmentMode(b *bundle.Bundle) error {
 		r.Experiments[i].Tags = append(r.Experiments[i].Tags, ml.ExperimentTag{Key: "dev", Value: b.Config.Workspace.CurrentUser.DisplayName})
 	}
 
+	for i := range r.ModelServingEndpoints {
+		prefix = "dev_" + b.Config.Workspace.CurrentUser.ShortName + "_"
+		r.ModelServingEndpoints[i].Name = prefix + r.ModelServingEndpoints[i].Name
+		// (model serving doesn't yet support tags)
+	}
+
 	return nil
 }


@@ -13,6 +13,7 @@ import (
 	"github.com/databricks/databricks-sdk-go/service/jobs"
 	"github.com/databricks/databricks-sdk-go/service/ml"
 	"github.com/databricks/databricks-sdk-go/service/pipelines"
+	"github.com/databricks/databricks-sdk-go/service/serving"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
@@ -53,6 +54,9 @@ func mockBundle(mode config.Mode) *bundle.Bundle {
 				Models: map[string]*resources.MlflowModel{
 					"model1": {Model: &ml.Model{Name: "model1"}},
 				},
+				ModelServingEndpoints: map[string]*resources.ModelServingEndpoint{
+					"servingendpoint1": {CreateServingEndpoint: &serving.CreateServingEndpoint{Name: "servingendpoint1"}},
+				},
 			},
 		},
 	}
@@ -69,6 +73,7 @@ func TestProcessTargetModeDevelopment(t *testing.T) {
 	assert.Equal(t, "/Users/lennart.kats@databricks.com/[dev lennart] experiment1", bundle.Config.Resources.Experiments["experiment1"].Name)
 	assert.Equal(t, "[dev lennart] experiment2", bundle.Config.Resources.Experiments["experiment2"].Name)
 	assert.Equal(t, "[dev lennart] model1", bundle.Config.Resources.Models["model1"].Name)
+	assert.Equal(t, "dev_lennart_servingendpoint1", bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name)
 	assert.Equal(t, "dev", bundle.Config.Resources.Experiments["experiment1"].Experiment.Tags[0].Key)
 	assert.True(t, bundle.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development)
 }
@@ -82,6 +87,7 @@ func TestProcessTargetModeDefault(t *testing.T) {
 	assert.Equal(t, "job1", bundle.Config.Resources.Jobs["job1"].Name)
 	assert.Equal(t, "pipeline1", bundle.Config.Resources.Pipelines["pipeline1"].Name)
 	assert.False(t, bundle.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development)
+	assert.Equal(t, "servingendpoint1", bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name)
 }
 
 func TestProcessTargetModeProduction(t *testing.T) {
@@ -109,6 +115,7 @@ func TestProcessTargetModeProduction(t *testing.T) {
 	bundle.Config.Resources.Experiments["experiment1"].Permissions = permissions
 	bundle.Config.Resources.Experiments["experiment2"].Permissions = permissions
 	bundle.Config.Resources.Models["model1"].Permissions = permissions
+	bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Permissions = permissions
 
 	err = validateProductionMode(context.Background(), bundle, false)
 	require.NoError(t, err)
@@ -116,6 +123,7 @@ func TestProcessTargetModeProduction(t *testing.T) {
 	assert.Equal(t, "job1", bundle.Config.Resources.Jobs["job1"].Name)
 	assert.Equal(t, "pipeline1", bundle.Config.Resources.Pipelines["pipeline1"].Name)
 	assert.False(t, bundle.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development)
+	assert.Equal(t, "servingendpoint1", bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name)
 }
 
 func TestProcessTargetModeProductionOkForPrincipal(t *testing.T) {


@@ -3,8 +3,10 @@ package mutator
 import (
 	"context"
 	"fmt"
+	"strings"
 
 	"github.com/databricks/cli/bundle"
+	"golang.org/x/exp/maps"
 )
 
 type selectTarget struct {
@@ -30,7 +32,7 @@ func (m *selectTarget) Apply(_ context.Context, b *bundle.Bundle) error {
 	// Get specified target
 	target, ok := b.Config.Targets[m.name]
 	if !ok {
-		return fmt.Errorf("%s: no such target", m.name)
+		return fmt.Errorf("%s: no such target. Available targets: %s", m.name, strings.Join(maps.Keys(b.Config.Targets), ", "))
 	}
 
 	// Merge specified target into root configuration structure.
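The improved error above lists the available targets by joining the map's keys via `golang.org/x/exp/maps`. A stdlib-only sketch of the same idea, with sorting added so the message is deterministic (`maps.Keys` returns keys in unspecified order):

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

// availableTargets renders a map's keys as a sorted, comma-separated
// list for use in a "no such target" error message.
func availableTargets(targets map[string]struct{}) string {
	keys := make([]string, 0, len(targets))
	for k := range targets {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	return strings.Join(keys, ", ")
}

func main() {
	targets := map[string]struct{}{"prod": {}, "dev": {}}
	err := fmt.Errorf("%s: no such target. Available targets: %s", "staging", availableTargets(targets))
	fmt.Println(err) // staging: no such target. Available targets: dev, prod
}
```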


@@ -3,10 +3,10 @@ package mutator

 import (
 	"context"
 	"fmt"
-	"os"

 	"github.com/databricks/cli/bundle"
 	"github.com/databricks/cli/bundle/config/variable"
+	"github.com/databricks/cli/libs/env"
 )

 const bundleVarPrefix = "BUNDLE_VAR_"
@@ -21,7 +21,7 @@ func (m *setVariables) Name() string {
 	return "SetVariables"
 }

-func setVariable(v *variable.Variable, name string) error {
+func setVariable(ctx context.Context, v *variable.Variable, name string) error {
 	// case: variable already has value initialized, so skip
 	if v.HasValue() {
 		return nil
@@ -29,7 +29,7 @@ func setVariable(v *variable.Variable, name string) error {
 	// case: read and set variable value from process environment
 	envVarName := bundleVarPrefix + name
-	if val, ok := os.LookupEnv(envVarName); ok {
+	if val, ok := env.Lookup(ctx, envVarName); ok {
 		err := v.Set(val)
 		if err != nil {
 			return fmt.Errorf(`failed to assign value "%s" to variable %s from environment variable %s with error: %w`, val, name, envVarName, err)
@@ -54,7 +54,7 @@ func setVariable(v *variable.Variable, name string) error {
 func (m *setVariables) Apply(ctx context.Context, b *bundle.Bundle) error {
 	for name, variable := range b.Config.Variables {
-		err := setVariable(variable, name)
+		err := setVariable(ctx, variable, name)
 		if err != nil {
 			return err
 		}
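The mutator above resolves each variable in a fixed order: an already-assigned value wins, then a `BUNDLE_VAR_<name>` environment variable, then the declared default. A standalone sketch of that precedence, using `os.LookupEnv` directly instead of the CLI's context-aware `env.Lookup`; `resolveVariable` is a hypothetical helper, not the CLI's `setVariable`:

```go
package main

import (
	"fmt"
	"os"
)

const bundleVarPrefix = "BUNDLE_VAR_"

// resolveVariable returns, in order of precedence: an already-assigned
// value, the BUNDLE_VAR_<name> environment variable, or the default.
// It errors only when none of the three sources provides a value.
func resolveVariable(name string, assigned *string, def *string) (string, error) {
	if assigned != nil {
		return *assigned, nil
	}
	if val, ok := os.LookupEnv(bundleVarPrefix + name); ok {
		return val, nil
	}
	if def != nil {
		return *def, nil
	}
	return "", fmt.Errorf("no value assigned to required variable %s", name)
}

func main() {
	os.Setenv("BUNDLE_VAR_foo", "process-env")
	v, _ := resolveVariable("foo", nil, nil)
	fmt.Println(v)
}
```

The test cases in the diff (`TestSetVariableFromProcessEnvVar`, `TestSetVariableUsingDefaultValue`, and so on) exercise exactly these branches.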


@@ -21,7 +21,7 @@ func TestSetVariableFromProcessEnvVar(t *testing.T) {
 	// set value for variable as an environment variable
 	t.Setenv("BUNDLE_VAR_foo", "process-env")

-	err := setVariable(&variable, "foo")
+	err := setVariable(context.Background(), &variable, "foo")
 	require.NoError(t, err)
 	assert.Equal(t, *variable.Value, "process-env")
 }
@@ -33,7 +33,7 @@ func TestSetVariableUsingDefaultValue(t *testing.T) {
 		Default: &defaultVal,
 	}

-	err := setVariable(&variable, "foo")
+	err := setVariable(context.Background(), &variable, "foo")
 	require.NoError(t, err)
 	assert.Equal(t, *variable.Value, "default")
 }
@@ -49,7 +49,7 @@ func TestSetVariableWhenAlreadyAValueIsAssigned(t *testing.T) {
 	// since a value is already assigned to the variable, it would not be overridden
 	// by the default value
-	err := setVariable(&variable, "foo")
+	err := setVariable(context.Background(), &variable, "foo")
 	require.NoError(t, err)
 	assert.Equal(t, *variable.Value, "assigned-value")
 }
@@ -68,7 +68,7 @@ func TestSetVariableEnvVarValueDoesNotOverridePresetValue(t *testing.T) {
 	// since a value is already assigned to the variable, it would not be overridden
 	// by the value from environment
-	err := setVariable(&variable, "foo")
+	err := setVariable(context.Background(), &variable, "foo")
 	require.NoError(t, err)
 	assert.Equal(t, *variable.Value, "assigned-value")
 }
@@ -79,7 +79,7 @@ func TestSetVariablesErrorsIfAValueCouldNotBeResolved(t *testing.T) {
 	}

 	// fails because we could not resolve a value for the variable
-	err := setVariable(&variable, "foo")
+	err := setVariable(context.Background(), &variable, "foo")
 	assert.ErrorContains(t, err, "no value assigned to required variable foo. Assignment can be done through the \"--var\" flag or by setting the BUNDLE_VAR_foo environment variable")
 }


@@ -38,7 +38,7 @@ func (m *trampoline) Name() string {
 func (m *trampoline) Apply(ctx context.Context, b *bundle.Bundle) error {
 	tasks := m.functions.GetTasks(b)
 	for _, task := range tasks {
-		err := m.generateNotebookWrapper(b, task)
+		err := m.generateNotebookWrapper(ctx, b, task)
 		if err != nil {
 			return err
 		}
@@ -46,8 +46,8 @@ func (m *trampoline) Apply(ctx context.Context, b *bundle.Bundle) error {
 	return nil
 }

-func (m *trampoline) generateNotebookWrapper(b *bundle.Bundle, task jobs_utils.TaskWithJobKey) error {
-	internalDir, err := b.InternalDir()
+func (m *trampoline) generateNotebookWrapper(ctx context.Context, b *bundle.Bundle, task jobs_utils.TaskWithJobKey) error {
+	internalDir, err := b.InternalDir(ctx)
 	if err != nil {
 		return err
 	}


@@ -11,6 +11,7 @@ import (
 	"github.com/databricks/cli/bundle/config"
 	"github.com/databricks/cli/bundle/config/paths"
 	"github.com/databricks/cli/bundle/config/resources"
+	jobs_utils "github.com/databricks/cli/libs/jobs"
 	"github.com/databricks/databricks-sdk-go/service/jobs"
 	"github.com/stretchr/testify/require"
 )
@@ -18,10 +19,10 @@ import (
 type functions struct {
 }

-func (f *functions) GetTasks(b *bundle.Bundle) []TaskWithJobKey {
-	tasks := make([]TaskWithJobKey, 0)
+func (f *functions) GetTasks(b *bundle.Bundle) []jobs_utils.TaskWithJobKey {
+	tasks := make([]jobs_utils.TaskWithJobKey, 0)
 	for k := range b.Config.Resources.Jobs["test"].Tasks {
-		tasks = append(tasks, TaskWithJobKey{
+		tasks = append(tasks, jobs_utils.TaskWithJobKey{
 			JobKey: "test",
 			Task:   &b.Config.Resources.Jobs["test"].Tasks[k],
 		})
@@ -88,7 +89,7 @@ func TestGenerateTrampoline(t *testing.T) {
 	err := bundle.Apply(ctx, b, trampoline)
 	require.NoError(t, err)

-	dir, err := b.InternalDir()
+	dir, err := b.InternalDir(ctx)
 	require.NoError(t, err)
 	filename := filepath.Join(dir, "notebook_test_trampoline_test_to_trampoline.py")


@@ -162,7 +162,7 @@ func TestTranslatePaths(t *testing.T) {
 				MainClassName: "HelloWorldRemote",
 			},
 			Libraries: []compute.Library{
-				{Jar: "dbfs:///bundle/dist/task_remote.jar"},
+				{Jar: "dbfs:/bundle/dist/task_remote.jar"},
 			},
 		},
 	},
@@ -243,7 +243,7 @@ func TestTranslatePaths(t *testing.T) {
 	)
 	assert.Equal(
 		t,
-		"dbfs:///bundle/dist/task_remote.jar",
+		"dbfs:/bundle/dist/task_remote.jar",
 		bundle.Config.Resources.Jobs["job"].Tasks[6].Libraries[0].Jar,
 	)


@@ -13,6 +13,7 @@ type Resources struct {
 	Models      map[string]*resources.MlflowModel      `json:"models,omitempty"`
 	Experiments map[string]*resources.MlflowExperiment `json:"experiments,omitempty"`
+	ModelServingEndpoints map[string]*resources.ModelServingEndpoint `json:"model_serving_endpoints,omitempty"`
 }

 type UniqueResourceIdTracker struct {
@@ -93,6 +94,19 @@ func (r *Resources) VerifyUniqueResourceIdentifiers() (*UniqueResourceIdTracker,
 		tracker.Type[k] = "mlflow_experiment"
 		tracker.ConfigPath[k] = r.Experiments[k].ConfigFilePath
 	}
+	for k := range r.ModelServingEndpoints {
+		if _, ok := tracker.Type[k]; ok {
+			return tracker, fmt.Errorf("multiple resources named %s (%s at %s, %s at %s)",
+				k,
+				tracker.Type[k],
+				tracker.ConfigPath[k],
+				"model_serving_endpoint",
+				r.ModelServingEndpoints[k].ConfigFilePath,
+			)
+		}
+		tracker.Type[k] = "model_serving_endpoint"
+		tracker.ConfigPath[k] = r.ModelServingEndpoints[k].ConfigFilePath
+	}
 	return tracker, nil
 }
@@ -112,6 +126,9 @@ func (r *Resources) SetConfigFilePath(path string) {
 	for _, e := range r.Experiments {
 		e.ConfigFilePath = path
 	}
+	for _, e := range r.ModelServingEndpoints {
+		e.ConfigFilePath = path
+	}
 }

 // MergeJobClusters iterates over all jobs and merges their job clusters.
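The uniqueness check extended above folds every resource map into one tracker so that, for example, a job and a model serving endpoint cannot share a key. The core of that check, sketched with plain string maps instead of the bundle's resource types; `verifyUnique` and its inputs are illustrative only:

```go
package main

import "fmt"

// verifyUnique reports the first key that appears in more than one
// resource group, mirroring the tracker used by
// VerifyUniqueResourceIdentifiers in the diff above.
func verifyUnique(groups map[string][]string) error {
	seen := map[string]string{} // key -> resource type that claimed it
	for typ, keys := range groups {
		for _, k := range keys {
			if prev, ok := seen[k]; ok {
				return fmt.Errorf("multiple resources named %s (%s, %s)", k, prev, typ)
			}
			seen[k] = typ
		}
	}
	return nil
}

func main() {
	// Hypothetical bundle: "foo" is declared both as a job and as a
	// model serving endpoint, which must be rejected.
	err := verifyUnique(map[string][]string{
		"job":                    {"foo"},
		"model_serving_endpoint": {"foo"},
	})
	fmt.Println(err != nil)
}
```

Note that Go map iteration order is nondeterministic, so which group is reported first may vary; the real tracker also records config file paths so the error can point at both definitions.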


@@ -0,0 +1,24 @@
+package resources
+
+import (
+	"github.com/databricks/cli/bundle/config/paths"
+	"github.com/databricks/databricks-sdk-go/service/serving"
+)
+
+type ModelServingEndpoint struct {
+	// This represents the input args for terraform, and will get converted
+	// to a HCL representation for CRUD
+	*serving.CreateServingEndpoint
+
+	// This represents the id (ie serving_endpoint_id) that can be used
+	// as a reference in other resources. This value is returned by terraform.
+	ID string
+
+	// Local path where the bundle is defined. All bundle resources include
+	// this for interpolation purposes.
+	paths.Paths
+
+	// This is a resource agnostic implementation of permissions for ACLs.
+	// Implementation could be different based on the resource type.
+	Permissions []Permission `json:"permissions,omitempty"`
+}


@@ -52,7 +52,7 @@ type Root struct {
 	// Bundle contains details about this bundle, such as its name,
 	// version of the spec (TODO), default cluster, default warehouse, etc.
-	Bundle Bundle `json:"bundle"`
+	Bundle Bundle `json:"bundle,omitempty"`

 	// Include specifies a list of patterns of file names to load and
 	// merge into the this configuration. Only includes defined in the root
@@ -80,7 +80,7 @@ type Root struct {
 	Environments map[string]*Target `json:"environments,omitempty"`

 	// Sync section specifies options for files synchronization
-	Sync Sync `json:"sync"`
+	Sync Sync `json:"sync,omitempty"`

 	// RunAs section allows to define an execution identity for jobs and pipelines runs
 	RunAs *jobs.JobRunAs `json:"run_as,omitempty"`


@@ -9,12 +9,12 @@ import (
 )

 func getSync(ctx context.Context, b *bundle.Bundle) (*sync.Sync, error) {
-	cacheDir, err := b.CacheDir()
+	cacheDir, err := b.CacheDir(ctx)
 	if err != nil {
 		return nil, fmt.Errorf("cannot get bundle cache directory: %w", err)
 	}

-	includes, err := b.GetSyncIncludePatterns()
+	includes, err := b.GetSyncIncludePatterns(ctx)
 	if err != nil {
 		return nil, fmt.Errorf("cannot get list of sync includes: %w", err)
 	}


@@ -16,6 +16,10 @@ func (w *apply) Name() string {
 }

 func (w *apply) Apply(ctx context.Context, b *bundle.Bundle) error {
+	if b.TerraformHasNoResources {
+		cmdio.LogString(ctx, "Note: there are no resources to deploy for this bundle")
+		return nil
+	}
 	tf := b.Terraform
 	if tf == nil {
 		return fmt.Errorf("terraform not initialized")


@@ -49,12 +49,14 @@ func convPermission(ac resources.Permission) schema.ResourcePermissionsAccessCon
 //
 // NOTE: THIS IS CURRENTLY A HACK. WE NEED A BETTER WAY TO
 // CONVERT TO/FROM TERRAFORM COMPATIBLE FORMAT.
-func BundleToTerraform(config *config.Root) *schema.Root {
+func BundleToTerraform(config *config.Root) (*schema.Root, bool) {
 	tfroot := schema.NewRoot()
 	tfroot.Provider = schema.NewProviders()
 	tfroot.Resource = schema.NewResources()
+	noResources := true

 	for k, src := range config.Resources.Jobs {
+		noResources = false
 		var dst schema.ResourceJob
 		conv(src, &dst)
@@ -88,6 +90,12 @@ func BundleToTerraform(config *config.Root) *schema.Root {
 					Tag:    git.GitTag,
 				}
 			}
+
+			for _, v := range src.Parameters {
+				var t schema.ResourceJobParameter
+				conv(v, &t)
+				dst.Parameter = append(dst.Parameter, t)
+			}
 		}

 		tfroot.Resource.Job[k] = &dst
@@ -100,6 +108,7 @@ func BundleToTerraform(config *config.Root) *schema.Root {
 	}

 	for k, src := range config.Resources.Pipelines {
+		noResources = false
 		var dst schema.ResourcePipeline
 		conv(src, &dst)
@@ -127,6 +136,7 @@ func BundleToTerraform(config *config.Root) *schema.Root {
 	}

 	for k, src := range config.Resources.Models {
+		noResources = false
 		var dst schema.ResourceMlflowModel
 		conv(src, &dst)
 		tfroot.Resource.MlflowModel[k] = &dst
@@ -139,6 +149,7 @@ func BundleToTerraform(config *config.Root) *schema.Root {
 	}

 	for k, src := range config.Resources.Experiments {
+		noResources = false
 		var dst schema.ResourceMlflowExperiment
 		conv(src, &dst)
 		tfroot.Resource.MlflowExperiment[k] = &dst
@@ -150,7 +161,20 @@ func BundleToTerraform(config *config.Root) *schema.Root {
 		}
 	}

-	return tfroot
+	for k, src := range config.Resources.ModelServingEndpoints {
+		noResources = false
+		var dst schema.ResourceModelServing
+		conv(src, &dst)
+		tfroot.Resource.ModelServing[k] = &dst
+
+		// Configure permissions for this resource.
+		if rp := convPermissions(src.Permissions); rp != nil {
+			rp.ServingEndpointId = fmt.Sprintf("${databricks_model_serving.%s.serving_endpoint_id}", k)
+			tfroot.Resource.Permissions["model_serving_"+k] = rp
+		}
+	}
+
+	return tfroot, noResources
 }

 func TerraformToBundle(state *tfjson.State, config *config.Root) error {
@@ -185,6 +209,12 @@ func TerraformToBundle(state *tfjson.State, config *config.Root) error {
 			cur := config.Resources.Experiments[resource.Name]
 			conv(tmp, &cur)
 			config.Resources.Experiments[resource.Name] = cur
+		case "databricks_model_serving":
+			var tmp schema.ResourceModelServing
+			conv(resource.AttributeValues, &tmp)
+			cur := config.Resources.ModelServingEndpoints[resource.Name]
+			conv(tmp, &cur)
+			config.Resources.ModelServingEndpoints[resource.Name] = cur
 		case "databricks_permissions":
 			// Ignore; no need to pull these back into the configuration.
 		default:


@@ -9,6 +9,7 @@ import (
 	"github.com/databricks/databricks-sdk-go/service/jobs"
 	"github.com/databricks/databricks-sdk-go/service/ml"
 	"github.com/databricks/databricks-sdk-go/service/pipelines"
+	"github.com/databricks/databricks-sdk-go/service/serving"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
@@ -29,6 +30,16 @@ func TestConvertJob(t *testing.T) {
 				GitProvider: jobs.GitProviderGitHub,
 				GitUrl:      "https://github.com/foo/bar",
 			},
+			Parameters: []jobs.JobParameterDefinition{
+				{
+					Name:    "param1",
+					Default: "default1",
+				},
+				{
+					Name:    "param2",
+					Default: "default2",
+				},
+			},
 		},
 	}
@@ -40,10 +51,13 @@ func TestConvertJob(t *testing.T) {
 		},
 	}

-	out := BundleToTerraform(&config)
+	out, _ := BundleToTerraform(&config)
 	assert.Equal(t, "my job", out.Resource.Job["my_job"].Name)
 	assert.Len(t, out.Resource.Job["my_job"].JobCluster, 1)
 	assert.Equal(t, "https://github.com/foo/bar", out.Resource.Job["my_job"].GitSource.Url)
+	assert.Len(t, out.Resource.Job["my_job"].Parameter, 2)
+	assert.Equal(t, "param1", out.Resource.Job["my_job"].Parameter[0].Name)
+	assert.Equal(t, "param2", out.Resource.Job["my_job"].Parameter[1].Name)
 	assert.Nil(t, out.Data)
 }
@@ -65,7 +79,7 @@ func TestConvertJobPermissions(t *testing.T) {
 		},
 	}

-	out := BundleToTerraform(&config)
+	out, _ := BundleToTerraform(&config)
 	assert.NotEmpty(t, out.Resource.Permissions["job_my_job"].JobId)
 	assert.Len(t, out.Resource.Permissions["job_my_job"].AccessControl, 1)
@@ -101,7 +115,7 @@ func TestConvertJobTaskLibraries(t *testing.T) {
 		},
 	}

-	out := BundleToTerraform(&config)
+	out, _ := BundleToTerraform(&config)
 	assert.Equal(t, "my job", out.Resource.Job["my_job"].Name)
 	require.Len(t, out.Resource.Job["my_job"].Task, 1)
 	require.Len(t, out.Resource.Job["my_job"].Task[0].Library, 1)
@@ -135,7 +149,7 @@ func TestConvertPipeline(t *testing.T) {
 		},
 	}

-	out := BundleToTerraform(&config)
+	out, _ := BundleToTerraform(&config)
 	assert.Equal(t, "my pipeline", out.Resource.Pipeline["my_pipeline"].Name)
 	assert.Len(t, out.Resource.Pipeline["my_pipeline"].Library, 2)
 	assert.Nil(t, out.Data)
@@ -159,7 +173,7 @@ func TestConvertPipelinePermissions(t *testing.T) {
 		},
 	}

-	out := BundleToTerraform(&config)
+	out, _ := BundleToTerraform(&config)
 	assert.NotEmpty(t, out.Resource.Permissions["pipeline_my_pipeline"].PipelineId)
 	assert.Len(t, out.Resource.Permissions["pipeline_my_pipeline"].AccessControl, 1)
@@ -194,7 +208,7 @@ func TestConvertModel(t *testing.T) {
 		},
 	}

-	out := BundleToTerraform(&config)
+	out, _ := BundleToTerraform(&config)
 	assert.Equal(t, "name", out.Resource.MlflowModel["my_model"].Name)
 	assert.Equal(t, "description", out.Resource.MlflowModel["my_model"].Description)
 	assert.Len(t, out.Resource.MlflowModel["my_model"].Tags, 2)
@@ -223,7 +237,7 @@ func TestConvertModelPermissions(t *testing.T) {
 		},
 	}

-	out := BundleToTerraform(&config)
+	out, _ := BundleToTerraform(&config)
 	assert.NotEmpty(t, out.Resource.Permissions["mlflow_model_my_model"].RegisteredModelId)
 	assert.Len(t, out.Resource.Permissions["mlflow_model_my_model"].AccessControl, 1)
@@ -247,7 +261,7 @@ func TestConvertExperiment(t *testing.T) {
 		},
 	}

-	out := BundleToTerraform(&config)
+	out, _ := BundleToTerraform(&config)
 	assert.Equal(t, "name", out.Resource.MlflowExperiment["my_experiment"].Name)
 	assert.Nil(t, out.Data)
 }
@@ -270,7 +284,7 @@ func TestConvertExperimentPermissions(t *testing.T) {
 		},
 	}

-	out := BundleToTerraform(&config)
+	out, _ := BundleToTerraform(&config)
 	assert.NotEmpty(t, out.Resource.Permissions["mlflow_experiment_my_experiment"].ExperimentId)
 	assert.Len(t, out.Resource.Permissions["mlflow_experiment_my_experiment"].AccessControl, 1)
@@ -279,3 +293,76 @@ func TestConvertExperimentPermissions(t *testing.T) {
 	assert.Equal(t, "CAN_READ", p.PermissionLevel)
 }
+
+func TestConvertModelServing(t *testing.T) {
+	var src = resources.ModelServingEndpoint{
+		CreateServingEndpoint: &serving.CreateServingEndpoint{
+			Name: "name",
+			Config: serving.EndpointCoreConfigInput{
+				ServedModels: []serving.ServedModelInput{
+					{
+						ModelName:          "model_name",
+						ModelVersion:       "1",
+						ScaleToZeroEnabled: true,
+						WorkloadSize:       "Small",
+					},
+				},
+				TrafficConfig: &serving.TrafficConfig{
+					Routes: []serving.Route{
+						{
+							ServedModelName:   "model_name-1",
+							TrafficPercentage: 100,
+						},
+					},
+				},
+			},
+		},
+	}
+
+	var config = config.Root{
+		Resources: config.Resources{
+			ModelServingEndpoints: map[string]*resources.ModelServingEndpoint{
+				"my_model_serving_endpoint": &src,
+			},
+		},
+	}
+
+	out, _ := BundleToTerraform(&config)
+	resource := out.Resource.ModelServing["my_model_serving_endpoint"]
+
+	assert.Equal(t, "name", resource.Name)
+	assert.Equal(t, "model_name", resource.Config.ServedModels[0].ModelName)
+	assert.Equal(t, "1", resource.Config.ServedModels[0].ModelVersion)
+	assert.Equal(t, true, resource.Config.ServedModels[0].ScaleToZeroEnabled)
+	assert.Equal(t, "Small", resource.Config.ServedModels[0].WorkloadSize)
+	assert.Equal(t, "model_name-1", resource.Config.TrafficConfig.Routes[0].ServedModelName)
+	assert.Equal(t, 100, resource.Config.TrafficConfig.Routes[0].TrafficPercentage)
+	assert.Nil(t, out.Data)
+}
+
+func TestConvertModelServingPermissions(t *testing.T) {
+	var src = resources.ModelServingEndpoint{
+		Permissions: []resources.Permission{
+			{
+				Level:    "CAN_VIEW",
+				UserName: "jane@doe.com",
+			},
+		},
+	}
+
+	var config = config.Root{
+		Resources: config.Resources{
+			ModelServingEndpoints: map[string]*resources.ModelServingEndpoint{
+				"my_model_serving_endpoint": &src,
+			},
+		},
+	}
+
+	out, _ := BundleToTerraform(&config)
+	assert.NotEmpty(t, out.Resource.Permissions["model_serving_my_model_serving_endpoint"].ServingEndpointId)
+	assert.Len(t, out.Resource.Permissions["model_serving_my_model_serving_endpoint"].AccessControl, 1)
+
+	p := out.Resource.Permissions["model_serving_my_model_serving_endpoint"].AccessControl[0]
+	assert.Equal(t, "jane@doe.com", p.UserName)
+	assert.Equal(t, "CAN_VIEW", p.PermissionLevel)
+}


@@ -1,11 +1,13 @@
 package terraform

 import (
+	"context"
+
 	"github.com/databricks/cli/bundle"
 )

 // Dir returns the Terraform working directory for a given bundle.
 // The working directory is emphemeral and nested under the bundle's cache directory.
-func Dir(b *bundle.Bundle) (string, error) {
-	return b.CacheDir("terraform")
+func Dir(ctx context.Context, b *bundle.Bundle) (string, error) {
+	return b.CacheDir(ctx, "terraform")
 }


@@ -8,9 +8,11 @@ import (
 	"path/filepath"
 	"runtime"
 	"strings"
+	"time"

 	"github.com/databricks/cli/bundle"
 	"github.com/databricks/cli/bundle/config"
+	"github.com/databricks/cli/libs/env"
 	"github.com/databricks/cli/libs/log"
 	"github.com/hashicorp/go-version"
 	"github.com/hashicorp/hc-install/product"
@@ -37,7 +39,7 @@ func (m *initialize) findExecPath(ctx context.Context, b *bundle.Bundle, tf *con
 		return tf.ExecPath, nil
 	}

-	binDir, err := b.CacheDir("bin")
+	binDir, err := b.CacheDir(context.Background(), "bin")
 	if err != nil {
 		return "", err
 	}
@@ -59,6 +61,7 @@ func (m *initialize) findExecPath(ctx context.Context, b *bundle.Bundle, tf *con
 		Product:    product.Terraform,
 		Version:    version.Must(version.NewVersion("1.5.5")),
 		InstallDir: binDir,
+		Timeout:    1 * time.Minute,
 	}
 	execPath, err = installer.Install(ctx)
 	if err != nil {
@@ -71,17 +74,25 @@ func (m *initialize) findExecPath(ctx context.Context, b *bundle.Bundle, tf *con
 }

 // This function inherits some environment variables for Terraform CLI.
-func inheritEnvVars(env map[string]string) error {
+func inheritEnvVars(ctx context.Context, environ map[string]string) error {
 	// Include $HOME in set of environment variables to pass along.
-	home, ok := os.LookupEnv("HOME")
+	home, ok := env.Lookup(ctx, "HOME")
 	if ok {
-		env["HOME"] = home
+		environ["HOME"] = home
+	}
+
+	// Include $PATH in set of environment variables to pass along.
+	// This is necessary to ensure that our Terraform provider can use the
+	// same auxiliary programs (e.g. `az`, or `gcloud`) as the CLI.
+	path, ok := env.Lookup(ctx, "PATH")
+	if ok {
+		environ["PATH"] = path
 	}

 	// Include $TF_CLI_CONFIG_FILE to override terraform provider in development.
-	configFile, ok := os.LookupEnv("TF_CLI_CONFIG_FILE")
+	configFile, ok := env.Lookup(ctx, "TF_CLI_CONFIG_FILE")
 	if ok {
-		env["TF_CLI_CONFIG_FILE"] = configFile
+		environ["TF_CLI_CONFIG_FILE"] = configFile
 	}

 	return nil
@@ -95,40 +106,40 @@ func inheritEnvVars(env map[string]string) error {
 // the CLI and its dependencies do not have access to.
 //
 // see: os.TempDir for more context
-func setTempDirEnvVars(env map[string]string, b *bundle.Bundle) error {
+func setTempDirEnvVars(ctx context.Context, environ map[string]string, b *bundle.Bundle) error {
 	switch runtime.GOOS {
 	case "windows":
-		if v, ok := os.LookupEnv("TMP"); ok {
-			env["TMP"] = v
-		} else if v, ok := os.LookupEnv("TEMP"); ok {
-			env["TEMP"] = v
-		} else if v, ok := os.LookupEnv("USERPROFILE"); ok {
-			env["USERPROFILE"] = v
+		if v, ok := env.Lookup(ctx, "TMP"); ok {
+			environ["TMP"] = v
+		} else if v, ok := env.Lookup(ctx, "TEMP"); ok {
+			environ["TEMP"] = v
+		} else if v, ok := env.Lookup(ctx, "USERPROFILE"); ok {
+			environ["USERPROFILE"] = v
 		} else {
-			tmpDir, err := b.CacheDir("tmp")
+			tmpDir, err := b.CacheDir(ctx, "tmp")
 			if err != nil {
 				return err
 			}
-			env["TMP"] = tmpDir
+			environ["TMP"] = tmpDir
 		}
 	default:
 		// If TMPDIR is not set, we let the process fall back to its default value.
-		if v, ok := os.LookupEnv("TMPDIR"); ok {
-			env["TMPDIR"] = v
+		if v, ok := env.Lookup(ctx, "TMPDIR"); ok {
+			environ["TMPDIR"] = v
 		}
 	}
 	return nil
 }

 // This function passes through all proxy related environment variables.
-func setProxyEnvVars(env map[string]string, b *bundle.Bundle) error {
+func setProxyEnvVars(ctx context.Context, environ map[string]string, b *bundle.Bundle) error {
 	for _, v := range []string{"http_proxy", "https_proxy", "no_proxy"} {
 		// The case (upper or lower) is notoriously inconsistent for tools on Unix systems.
 		// We therefore try to read both the upper and lower case versions of the variable.
 		for _, v := range []string{strings.ToUpper(v), strings.ToLower(v)} {
-			if val, ok := os.LookupEnv(v); ok {
+			if val, ok := env.Lookup(ctx, v); ok {
 				// Only set uppercase version of the variable.
-				env[strings.ToUpper(v)] = val
+				environ[strings.ToUpper(v)] = val
 			}
 		}
 	}
@@ -147,7 +158,7 @@ func (m *initialize) Apply(ctx context.Context, b *bundle.Bundle) error {
 		return err
 	}

-	workingDir, err := Dir(b)
+	workingDir, err := Dir(ctx, b)
 	if err != nil {
 		return err
 	}
@@ -157,31 +168,31 @@ func (m *initialize) Apply(ctx context.Context, b *bundle.Bundle) error {
 		return err
 	}

-	env, err := b.AuthEnv()
+	environ, err := b.AuthEnv()
 	if err != nil {
 		return err
 	}

-	err = inheritEnvVars(env)
+	err = inheritEnvVars(ctx, environ)
 	if err != nil {
 		return err
 	}

 	// Set the temporary directory environment variables
-	err = setTempDirEnvVars(env, b)
+	err = setTempDirEnvVars(ctx, environ, b)
 	if err != nil {
 		return err
 	}

 	// Set the proxy related environment variables
-	err = setProxyEnvVars(env, b)
+	err = setProxyEnvVars(ctx, environ, b)
 	if err != nil {
 		return err
 	}

 	// Configure environment variables for auth for Terraform to use.
-	log.Debugf(ctx, "Environment variables for Terraform: %s", strings.Join(maps.Keys(env), ", "))
-	err = tf.SetEnv(env)
+	log.Debugf(ctx, "Environment variables for Terraform: %s", strings.Join(maps.Keys(environ), ", "))
+	err = tf.SetEnv(environ)
 	if err != nil {
 		return err
 	}
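`setProxyEnvVars` above reads both the lower- and upper-case spellings of each proxy variable and writes only the upper-case name into the map handed to Terraform. That normalization in isolation, with the lookup function injected instead of the CLI's context-aware `env.Lookup`; `proxyEnv` is a hypothetical name for this sketch:

```go
package main

import (
	"fmt"
	"os"
	"strings"
)

// proxyEnv copies http_proxy/https_proxy/no_proxy into the target map,
// accepting either case on input but writing only the upper-case name.
// The lookup function is injected so the logic can be exercised without
// mutating the real process environment.
func proxyEnv(environ map[string]string, lookup func(string) (string, bool)) {
	for _, v := range []string{"http_proxy", "https_proxy", "no_proxy"} {
		// Try both spellings; the later (lower-case) read wins if both are set.
		for _, name := range []string{strings.ToUpper(v), strings.ToLower(v)} {
			if val, ok := lookup(name); ok {
				environ[strings.ToUpper(v)] = val
			}
		}
	}
}

func main() {
	environ := map[string]string{}
	proxyEnv(environ, os.LookupEnv)
	fmt.Println(environ)
}
```

Note the ordering consequence of the inner loop: because the lower-case spelling is read second, its value overwrites the upper-case one when both are set, even though only the upper-case key is ever written.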


@@ -68,7 +68,7 @@ func TestSetTempDirEnvVarsForUnixWithTmpDirSet(t *testing.T) {
 	// compute env
 	env := make(map[string]string, 0)
-	err := setTempDirEnvVars(env, b)
+	err := setTempDirEnvVars(context.Background(), env, b)
 	require.NoError(t, err)
 	// Assert that we pass through TMPDIR.
@@ -96,7 +96,7 @@ func TestSetTempDirEnvVarsForUnixWithTmpDirNotSet(t *testing.T) {
 	// compute env
 	env := make(map[string]string, 0)
-	err := setTempDirEnvVars(env, b)
+	err := setTempDirEnvVars(context.Background(), env, b)
 	require.NoError(t, err)
 	// Assert that we don't pass through TMPDIR.
@@ -124,7 +124,7 @@ func TestSetTempDirEnvVarsForWindowWithAllTmpDirEnvVarsSet(t *testing.T) {
 	// compute env
 	env := make(map[string]string, 0)
-	err := setTempDirEnvVars(env, b)
+	err := setTempDirEnvVars(context.Background(), env, b)
 	require.NoError(t, err)
 	// assert that we pass through the highest priority env var value
@@ -154,7 +154,7 @@ func TestSetTempDirEnvVarsForWindowWithUserProfileAndTempSet(t *testing.T) {
 	// compute env
 	env := make(map[string]string, 0)
-	err := setTempDirEnvVars(env, b)
+	err := setTempDirEnvVars(context.Background(), env, b)
 	require.NoError(t, err)
 	// assert that we pass through the highest priority env var value
@@ -184,7 +184,7 @@ func TestSetTempDirEnvVarsForWindowWithUserProfileSet(t *testing.T) {
 	// compute env
 	env := make(map[string]string, 0)
-	err := setTempDirEnvVars(env, b)
+	err := setTempDirEnvVars(context.Background(), env, b)
 	require.NoError(t, err)
 	// assert that we pass through the user profile
@@ -214,11 +214,11 @@ func TestSetTempDirEnvVarsForWindowsWithoutAnyTempDirEnvVarsSet(t *testing.T) {
 	// compute env
 	env := make(map[string]string, 0)
-	err := setTempDirEnvVars(env, b)
+	err := setTempDirEnvVars(context.Background(), env, b)
 	require.NoError(t, err)
 	// assert TMP is set to b.CacheDir("tmp")
-	tmpDir, err := b.CacheDir("tmp")
+	tmpDir, err := b.CacheDir(context.Background(), "tmp")
 	require.NoError(t, err)
 	assert.Equal(t, map[string]string{
 		"TMP": tmpDir,
@@ -248,7 +248,7 @@ func TestSetProxyEnvVars(t *testing.T) {
 	// No proxy env vars set.
 	clearEnv()
 	env := make(map[string]string, 0)
-	err := setProxyEnvVars(env, b)
+	err := setProxyEnvVars(context.Background(), env, b)
 	require.NoError(t, err)
 	assert.Len(t, env, 0)
@@ -258,7 +258,7 @@ func TestSetProxyEnvVars(t *testing.T) {
 	t.Setenv("https_proxy", "foo")
 	t.Setenv("no_proxy", "foo")
 	env = make(map[string]string, 0)
-	err = setProxyEnvVars(env, b)
+	err = setProxyEnvVars(context.Background(), env, b)
 	require.NoError(t, err)
 	assert.ElementsMatch(t, []string{"HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY"}, maps.Keys(env))
@@ -268,7 +268,7 @@ func TestSetProxyEnvVars(t *testing.T) {
 	t.Setenv("HTTPS_PROXY", "foo")
 	t.Setenv("NO_PROXY", "foo")
 	env = make(map[string]string, 0)
-	err = setProxyEnvVars(env, b)
+	err = setProxyEnvVars(context.Background(), env, b)
 	require.NoError(t, err)
 	assert.ElementsMatch(t, []string{"HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY"}, maps.Keys(env))
 }
@@ -277,14 +277,16 @@ func TestInheritEnvVars(t *testing.T) {
 	env := map[string]string{}
 	t.Setenv("HOME", "/home/testuser")
+	t.Setenv("PATH", "/foo:/bar")
 	t.Setenv("TF_CLI_CONFIG_FILE", "/tmp/config.tfrc")
-	err := inheritEnvVars(env)
+	err := inheritEnvVars(context.Background(), env)
 	require.NoError(t, err)
 	require.Equal(t, map[string]string{
 		"HOME": "/home/testuser",
+		"PATH": "/foo:/bar",
 		"TF_CLI_CONFIG_FILE": "/tmp/config.tfrc",
 	}, env)
 }


@@ -25,6 +25,9 @@ func interpolateTerraformResourceIdentifiers(path string, lookup map[string]stri
 	case "experiments":
 		path = strings.Join(append([]string{"databricks_mlflow_experiment"}, parts[2:]...), interpolation.Delimiter)
 		return fmt.Sprintf("${%s}", path), nil
+	case "model_serving_endpoints":
+		path = strings.Join(append([]string{"databricks_model_serving"}, parts[2:]...), interpolation.Delimiter)
+		return fmt.Sprintf("${%s}", path), nil
 	default:
 		panic("TODO: " + parts[1])
 	}


@@ -40,7 +40,7 @@ func (p *plan) Apply(ctx context.Context, b *bundle.Bundle) error {
 	}
 	// Persist computed plan
-	tfDir, err := Dir(b)
+	tfDir, err := Dir(ctx, b)
 	if err != nil {
 		return err
 	}


@@ -25,7 +25,7 @@ func (l *statePull) Apply(ctx context.Context, b *bundle.Bundle) error {
 		return err
 	}
-	dir, err := Dir(b)
+	dir, err := Dir(ctx, b)
 	if err != nil {
 		return err
 	}


@@ -22,7 +22,7 @@ func (l *statePush) Apply(ctx context.Context, b *bundle.Bundle) error {
 		return err
 	}
-	dir, err := Dir(b)
+	dir, err := Dir(ctx, b)
 	if err != nil {
 		return err
 	}
@@ -32,6 +32,7 @@ func (l *statePush) Apply(ctx context.Context, b *bundle.Bundle) error {
 	if err != nil {
 		return err
 	}
+	defer local.Close()
 	// Upload state file from local cache directory to filer.
 	log.Infof(ctx, "Writing local state file to remote state directory")


@@ -16,12 +16,13 @@ func (w *write) Name() string {
 }
 func (w *write) Apply(ctx context.Context, b *bundle.Bundle) error {
-	dir, err := Dir(b)
+	dir, err := Dir(ctx, b)
 	if err != nil {
 		return err
 	}
-	root := BundleToTerraform(&b.Config)
+	root, noResources := BundleToTerraform(&b.Config)
+	b.TerraformHasNoResources = noResources
 	f, err := os.Create(filepath.Join(dir, "bundle.tf.json"))
 	if err != nil {
 		return err

bundle/env/env.go vendored Normal file

@@ -0,0 +1,18 @@
package env
import (
"context"
envlib "github.com/databricks/cli/libs/env"
)
// Return the value of the first environment variable that is set.
func get(ctx context.Context, variables []string) (string, bool) {
for _, v := range variables {
value, ok := envlib.Lookup(ctx, v)
if ok {
return value, true
}
}
return "", false
}

bundle/env/env_test.go vendored Normal file

@@ -0,0 +1,44 @@
package env
import (
"context"
"testing"
"github.com/databricks/cli/internal/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetWithRealEnvSingleVariable(t *testing.T) {
testutil.CleanupEnvironment(t)
t.Setenv("v1", "foo")
v, ok := get(context.Background(), []string{"v1"})
require.True(t, ok)
assert.Equal(t, "foo", v)
// Not set.
v, ok = get(context.Background(), []string{"v2"})
require.False(t, ok)
assert.Equal(t, "", v)
}
func TestGetWithRealEnvMultipleVariables(t *testing.T) {
testutil.CleanupEnvironment(t)
t.Setenv("v1", "foo")
for _, vars := range [][]string{
{"v1", "v2", "v3"},
{"v2", "v3", "v1"},
{"v3", "v1", "v2"},
} {
v, ok := get(context.Background(), vars)
require.True(t, ok)
assert.Equal(t, "foo", v)
}
// Not set.
v, ok := get(context.Background(), []string{"v2", "v3", "v4"})
require.False(t, ok)
assert.Equal(t, "", v)
}

bundle/env/includes.go vendored Normal file

@@ -0,0 +1,14 @@
package env
import "context"
// IncludesVariable names the environment variable that holds additional configuration paths to include
// during bundle configuration loading. Also see `bundle/config/mutator/process_root_includes.go`.
const IncludesVariable = "DATABRICKS_BUNDLE_INCLUDES"
// Includes returns the bundle Includes environment variable.
func Includes(ctx context.Context) (string, bool) {
return get(ctx, []string{
IncludesVariable,
})
}

bundle/env/includes_test.go vendored Normal file

@@ -0,0 +1,28 @@
package env
import (
"context"
"testing"
"github.com/databricks/cli/internal/testutil"
"github.com/stretchr/testify/assert"
)
func TestIncludes(t *testing.T) {
ctx := context.Background()
testutil.CleanupEnvironment(t)
t.Run("set", func(t *testing.T) {
t.Setenv("DATABRICKS_BUNDLE_INCLUDES", "foo")
includes, ok := Includes(ctx)
assert.True(t, ok)
assert.Equal(t, "foo", includes)
})
t.Run("not set", func(t *testing.T) {
includes, ok := Includes(ctx)
assert.False(t, ok)
assert.Equal(t, "", includes)
})
}

bundle/env/root.go vendored Normal file

@@ -0,0 +1,16 @@
package env
import "context"
// RootVariable names the environment variable that holds the bundle root path.
const RootVariable = "DATABRICKS_BUNDLE_ROOT"
// Root returns the bundle root environment variable.
func Root(ctx context.Context) (string, bool) {
return get(ctx, []string{
RootVariable,
// Primary variable name for the bundle root until v0.204.0.
"BUNDLE_ROOT",
})
}

bundle/env/root_test.go vendored Normal file

@@ -0,0 +1,43 @@
package env
import (
"context"
"testing"
"github.com/databricks/cli/internal/testutil"
"github.com/stretchr/testify/assert"
)
func TestRoot(t *testing.T) {
ctx := context.Background()
testutil.CleanupEnvironment(t)
t.Run("first", func(t *testing.T) {
t.Setenv("DATABRICKS_BUNDLE_ROOT", "foo")
root, ok := Root(ctx)
assert.True(t, ok)
assert.Equal(t, "foo", root)
})
t.Run("second", func(t *testing.T) {
t.Setenv("BUNDLE_ROOT", "foo")
root, ok := Root(ctx)
assert.True(t, ok)
assert.Equal(t, "foo", root)
})
t.Run("both set", func(t *testing.T) {
t.Setenv("DATABRICKS_BUNDLE_ROOT", "first")
t.Setenv("BUNDLE_ROOT", "second")
root, ok := Root(ctx)
assert.True(t, ok)
assert.Equal(t, "first", root)
})
t.Run("not set", func(t *testing.T) {
root, ok := Root(ctx)
assert.False(t, ok)
assert.Equal(t, "", root)
})
}

bundle/env/target.go vendored Normal file

@@ -0,0 +1,17 @@
package env
import "context"
// TargetVariable names the environment variable that holds the bundle target to use.
const TargetVariable = "DATABRICKS_BUNDLE_TARGET"
// Target returns the bundle target environment variable.
func Target(ctx context.Context) (string, bool) {
return get(ctx, []string{
TargetVariable,
// Primary variable name for the bundle target until v0.203.2.
// See https://github.com/databricks/cli/pull/670.
"DATABRICKS_BUNDLE_ENV",
})
}

bundle/env/target_test.go vendored Normal file

@@ -0,0 +1,43 @@
package env
import (
"context"
"testing"
"github.com/databricks/cli/internal/testutil"
"github.com/stretchr/testify/assert"
)
func TestTarget(t *testing.T) {
ctx := context.Background()
testutil.CleanupEnvironment(t)
t.Run("first", func(t *testing.T) {
t.Setenv("DATABRICKS_BUNDLE_TARGET", "foo")
target, ok := Target(ctx)
assert.True(t, ok)
assert.Equal(t, "foo", target)
})
t.Run("second", func(t *testing.T) {
t.Setenv("DATABRICKS_BUNDLE_ENV", "foo")
target, ok := Target(ctx)
assert.True(t, ok)
assert.Equal(t, "foo", target)
})
t.Run("both set", func(t *testing.T) {
t.Setenv("DATABRICKS_BUNDLE_TARGET", "first")
t.Setenv("DATABRICKS_BUNDLE_ENV", "second")
target, ok := Target(ctx)
assert.True(t, ok)
assert.Equal(t, "first", target)
})
t.Run("not set", func(t *testing.T) {
target, ok := Target(ctx)
assert.False(t, ok)
assert.Equal(t, "", target)
})
}

bundle/env/temp_dir.go vendored Normal file

@@ -0,0 +1,13 @@
package env
import "context"
// TempDirVariable names the environment variable that holds the temporary directory to use.
const TempDirVariable = "DATABRICKS_BUNDLE_TMP"
// TempDir returns the temporary directory to use.
func TempDir(ctx context.Context) (string, bool) {
return get(ctx, []string{
TempDirVariable,
})
}

bundle/env/temp_dir_test.go vendored Normal file

@@ -0,0 +1,28 @@
package env
import (
"context"
"testing"
"github.com/databricks/cli/internal/testutil"
"github.com/stretchr/testify/assert"
)
func TestTempDir(t *testing.T) {
ctx := context.Background()
testutil.CleanupEnvironment(t)
t.Run("set", func(t *testing.T) {
t.Setenv("DATABRICKS_BUNDLE_TMP", "foo")
tempDir, ok := TempDir(ctx)
assert.True(t, ok)
assert.Equal(t, "foo", tempDir)
})
t.Run("not set", func(t *testing.T) {
tempDir, ok := TempDir(ctx)
assert.False(t, ok)
assert.Equal(t, "", tempDir)
})
}


@@ -56,11 +56,11 @@ func findAllTasks(b *bundle.Bundle) []*jobs.Task {
 	return result
 }
-func FindAllWheelTasks(b *bundle.Bundle) []*jobs.Task {
+func FindAllWheelTasksWithLocalLibraries(b *bundle.Bundle) []*jobs.Task {
 	tasks := findAllTasks(b)
 	wheelTasks := make([]*jobs.Task, 0)
 	for _, task := range tasks {
-		if task.PythonWheelTask != nil {
+		if task.PythonWheelTask != nil && IsTaskWithLocalLibraries(task) {
 			wheelTasks = append(wheelTasks, task)
 		}
 	}
@@ -68,6 +68,27 @@ func FindAllWheelTasks(b *bundle.Bundle) []*jobs.Task {
 	return wheelTasks
 }
+func IsTaskWithLocalLibraries(task *jobs.Task) bool {
+	for _, l := range task.Libraries {
+		if isLocalLibrary(&l) {
+			return true
+		}
+	}
+	return false
+}
+func IsTaskWithWorkspaceLibraries(task *jobs.Task) bool {
+	for _, l := range task.Libraries {
+		path := libPath(&l)
+		if isWorkspacePath(path) {
+			return true
+		}
+	}
+	return false
+}
 func isMissingRequiredLibraries(task *jobs.Task) bool {
 	if task.Libraries != nil {
 		return false
@@ -165,8 +186,8 @@ func isRemoteStorageScheme(path string) bool {
 		return false
 	}
-	// If the path starts with scheme:// format, it's a correct remote storage scheme
-	return strings.HasPrefix(path, url.Scheme+"://")
+	// If the path starts with scheme:/ format, it's a correct remote storage scheme
+	return strings.HasPrefix(path, url.Scheme+":/")
}
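The prefix change from `scheme://` to `scheme:/` matters because DBFS URIs are commonly written with a single slash (`dbfs:/path`), which the old check misclassified as local. A minimal sketch of the adjusted check, using only `net/url` from the standard library:

```go
package main

import (
	"fmt"
	"net/url"
	"strings"
)

// isRemote is a sketch of the adjusted scheme check: matching "scheme:/"
// instead of "scheme://" also classifies single-slash URIs such as
// dbfs:/path as remote storage, while relative paths stay local.
func isRemote(path string) bool {
	u, err := url.Parse(path)
	if err != nil || u.Scheme == "" {
		return false
	}
	return strings.HasPrefix(path, u.Scheme+":/")
}

func main() {
	fmt.Println(isRemote("dbfs:/FileStore/dist/test.whl")) // true
	fmt.Println(isRemote("s3://bucket/pkg.whl"))           // true
	fmt.Println(isRemote("./dist/test.whl"))               // false
}
```

The real implementation layers additional checks (e.g. Windows drive letters and workspace paths) on top of this; the test case added below (`dbfs:/path/to/package`) exercises exactly this single-slash form.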


@@ -16,6 +16,7 @@ var testCases map[string]bool = map[string]bool{
 	"file://path/to/package": true,
 	"C:\\path\\to\\package": true,
 	"dbfs://path/to/package": false,
+	"dbfs:/path/to/package": false,
 	"s3://path/to/package": false,
 	"abfss://path/to/package": false,
 }


@@ -7,6 +7,7 @@ import (
 	"github.com/databricks/cli/bundle"
 	"github.com/databricks/cli/bundle/config/mutator"
+	"github.com/databricks/cli/bundle/libraries"
 	jobs_utils "github.com/databricks/cli/libs/jobs"
 	"github.com/databricks/databricks-sdk-go/service/jobs"
 )
@@ -70,10 +71,13 @@ func (t *pythonTrampoline) GetTemplate(b *bundle.Bundle, task *jobs.Task) (strin
 func (t *pythonTrampoline) GetTasks(b *bundle.Bundle) []jobs_utils.TaskWithJobKey {
 	return jobs_utils.GetTasksWithJobKeyBy(b, func(task *jobs.Task) bool {
-		return task.PythonWheelTask != nil
+		return task.PythonWheelTask != nil && needsTrampoline(task)
 	})
 }
+func needsTrampoline(task *jobs.Task) bool {
+	return libraries.IsTaskWithWorkspaceLibraries(task)
+}
 func (t *pythonTrampoline) GetTemplateData(_ *bundle.Bundle, task *jobs.Task) (map[string]any, error) {
 	params, err := t.generateParameters(task.PythonWheelTask)
 	if err != nil {


@@ -9,6 +9,7 @@ import (
 	"github.com/databricks/cli/bundle/config"
 	"github.com/databricks/cli/bundle/config/paths"
 	"github.com/databricks/cli/bundle/config/resources"
+	"github.com/databricks/databricks-sdk-go/service/compute"
 	"github.com/databricks/databricks-sdk-go/service/jobs"
 	"github.com/stretchr/testify/require"
 )
@@ -82,11 +83,21 @@ func TestTransformFiltersWheelTasksOnly(t *testing.T) {
 	{
 		TaskKey: "key1",
 		PythonWheelTask: &jobs.PythonWheelTask{},
+		Libraries: []compute.Library{
+			{Whl: "/Workspace/Users/test@test.com/bundle/dist/test.whl"},
+		},
 	},
 	{
 		TaskKey: "key2",
 		NotebookTask: &jobs.NotebookTask{},
 	},
+	{
+		TaskKey: "key3",
+		PythonWheelTask: &jobs.PythonWheelTask{},
+		Libraries: []compute.Library{
+			{Whl: "dbfs:/FileStore/dist/test.whl"},
+		},
+	},
 	},
 },
},


@@ -1,21 +1,21 @@
 package bundle
 import (
+	"context"
 	"fmt"
 	"os"
 	"github.com/databricks/cli/bundle/config"
+	"github.com/databricks/cli/bundle/env"
 	"github.com/databricks/cli/folders"
 )
-const envBundleRoot = "BUNDLE_ROOT"
-// getRootEnv returns the value of the `BUNDLE_ROOT` environment variable
+// getRootEnv returns the value of the bundle root environment variable
 // if it set and is a directory. If the environment variable is set but
 // is not a directory, it returns an error. If the environment variable is
 // not set, it returns an empty string.
-func getRootEnv() (string, error) {
-	path, ok := os.LookupEnv(envBundleRoot)
+func getRootEnv(ctx context.Context) (string, error) {
+	path, ok := env.Root(ctx)
 	if !ok {
 		return "", nil
 	}
@@ -24,7 +24,7 @@ func getRootEnv() (string, error) {
 		err = fmt.Errorf("not a directory")
 	}
 	if err != nil {
-		return "", fmt.Errorf(`invalid bundle root %s="%s": %w`, envBundleRoot, path, err)
+		return "", fmt.Errorf(`invalid bundle root %s="%s": %w`, env.RootVariable, path, err)
 	}
 	return path, nil
 }
@@ -48,8 +48,8 @@ func getRootWithTraversal() (string, error) {
 }
 // mustGetRoot returns a bundle root or an error if one cannot be found.
-func mustGetRoot() (string, error) {
-	path, err := getRootEnv()
+func mustGetRoot(ctx context.Context) (string, error) {
+	path, err := getRootEnv(ctx)
 	if path != "" || err != nil {
 		return path, err
 	}
@@ -57,9 +57,9 @@ func mustGetRoot() (string, error) {
 }
 // tryGetRoot returns a bundle root or an empty string if one cannot be found.
-func tryGetRoot() (string, error) {
+func tryGetRoot(ctx context.Context) (string, error) {
 	// Note: an invalid value in the environment variable is still an error.
-	path, err := getRootEnv()
+	path, err := getRootEnv(ctx)
 	if path != "" || err != nil {
 		return path, err
 	}


@@ -7,6 +7,7 @@ import (
 	"testing"
 	"github.com/databricks/cli/bundle/config"
+	"github.com/databricks/cli/bundle/env"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
@@ -32,49 +33,55 @@ func chdir(t *testing.T, dir string) string {
 }
 func TestRootFromEnv(t *testing.T) {
+	ctx := context.Background()
 	dir := t.TempDir()
-	t.Setenv(envBundleRoot, dir)
+	t.Setenv(env.RootVariable, dir)
 	// It should pull the root from the environment variable.
-	root, err := mustGetRoot()
+	root, err := mustGetRoot(ctx)
 	require.NoError(t, err)
 	require.Equal(t, root, dir)
 }
 func TestRootFromEnvDoesntExist(t *testing.T) {
+	ctx := context.Background()
 	dir := t.TempDir()
-	t.Setenv(envBundleRoot, filepath.Join(dir, "doesntexist"))
+	t.Setenv(env.RootVariable, filepath.Join(dir, "doesntexist"))
 	// It should pull the root from the environment variable.
-	_, err := mustGetRoot()
+	_, err := mustGetRoot(ctx)
 	require.Errorf(t, err, "invalid bundle root")
 }
 func TestRootFromEnvIsFile(t *testing.T) {
+	ctx := context.Background()
 	dir := t.TempDir()
 	f, err := os.Create(filepath.Join(dir, "invalid"))
 	require.NoError(t, err)
 	f.Close()
-	t.Setenv(envBundleRoot, f.Name())
+	t.Setenv(env.RootVariable, f.Name())
 	// It should pull the root from the environment variable.
-	_, err = mustGetRoot()
+	_, err = mustGetRoot(ctx)
 	require.Errorf(t, err, "invalid bundle root")
 }
 func TestRootIfEnvIsEmpty(t *testing.T) {
+	ctx := context.Background()
 	dir := ""
-	t.Setenv(envBundleRoot, dir)
+	t.Setenv(env.RootVariable, dir)
 	// It should pull the root from the environment variable.
-	_, err := mustGetRoot()
+	_, err := mustGetRoot(ctx)
 	require.Errorf(t, err, "invalid bundle root")
 }
 func TestRootLookup(t *testing.T) {
+	ctx := context.Background()
 	// Have to set then unset to allow the testing package to revert it to its original value.
-	t.Setenv(envBundleRoot, "")
-	os.Unsetenv(envBundleRoot)
+	t.Setenv(env.RootVariable, "")
+	os.Unsetenv(env.RootVariable)
 	chdir(t, t.TempDir())
@@ -89,27 +96,30 @@ func TestRootLookup(t *testing.T) {
 	// It should find the project root from $PWD.
 	wd := chdir(t, "./a/b/c")
-	root, err := mustGetRoot()
+	root, err := mustGetRoot(ctx)
 	require.NoError(t, err)
 	require.Equal(t, wd, root)
 }
 func TestRootLookupError(t *testing.T) {
+	ctx := context.Background()
 	// Have to set then unset to allow the testing package to revert it to its original value.
-	t.Setenv(envBundleRoot, "")
-	os.Unsetenv(envBundleRoot)
+	t.Setenv(env.RootVariable, "")
+	os.Unsetenv(env.RootVariable)
 	// It can't find a project root from a temporary directory.
 	_ = chdir(t, t.TempDir())
-	_, err := mustGetRoot()
+	_, err := mustGetRoot(ctx)
 	require.ErrorContains(t, err, "unable to locate bundle root")
 }
 func TestLoadYamlWhenIncludesEnvPresent(t *testing.T) {
+	ctx := context.Background()
 	chdir(t, filepath.Join(".", "tests", "basic"))
-	t.Setenv(ExtraIncludePathsKey, "test")
+	t.Setenv(env.IncludesVariable, "test")
-	bundle, err := MustLoad(context.Background())
+	bundle, err := MustLoad(ctx)
 	assert.NoError(t, err)
 	assert.Equal(t, "basic", bundle.Config.Bundle.Name)
@@ -119,30 +129,33 @@ func TestLoadYamlWhenIncludesEnvPresent(t *testing.T) {
 }
 func TestLoadDefautlBundleWhenNoYamlAndRootAndIncludesEnvPresent(t *testing.T) {
+	ctx := context.Background()
 	dir := t.TempDir()
 	chdir(t, dir)
-	t.Setenv(envBundleRoot, dir)
-	t.Setenv(ExtraIncludePathsKey, "test")
+	t.Setenv(env.RootVariable, dir)
+	t.Setenv(env.IncludesVariable, "test")
-	bundle, err := MustLoad(context.Background())
+	bundle, err := MustLoad(ctx)
 	assert.NoError(t, err)
 	assert.Equal(t, dir, bundle.Config.Path)
 }
 func TestErrorIfNoYamlNoRootEnvAndIncludesEnvPresent(t *testing.T) {
+	ctx := context.Background()
 	dir := t.TempDir()
 	chdir(t, dir)
-	t.Setenv(ExtraIncludePathsKey, "test")
+	t.Setenv(env.IncludesVariable, "test")
-	_, err := MustLoad(context.Background())
+	_, err := MustLoad(ctx)
 	assert.Error(t, err)
 }
 func TestErrorIfNoYamlNoIncludesEnvAndRootEnvPresent(t *testing.T) {
+	ctx := context.Background()
 	dir := t.TempDir()
 	chdir(t, dir)
-	t.Setenv(envBundleRoot, dir)
+	t.Setenv(env.RootVariable, dir)
-	_, err := MustLoad(context.Background())
+	_, err := MustLoad(ctx)
 	assert.Error(t, err)
 }


@@ -95,6 +95,13 @@ type jobRunner struct {
 	job *resources.Job
 }
+func (r *jobRunner) Name() string {
+	if r.job == nil || r.job.JobSettings == nil {
+		return ""
+	}
+	return r.job.JobSettings.Name
+}
 func isFailed(task jobs.RunTask) bool {
 	return task.State.LifeCycleState == jobs.RunLifeCycleStateInternalError ||
 		(task.State.LifeCycleState == jobs.RunLifeCycleStateTerminated &&


@@ -4,6 +4,7 @@ import (
 	"fmt"
 	"github.com/databricks/cli/bundle"
+	"golang.org/x/exp/maps"
 )
 // RunnerLookup maps identifiers to a list of workloads that match that identifier.
@@ -32,18 +33,20 @@ func ResourceKeys(b *bundle.Bundle) (keyOnly RunnerLookup, keyWithType RunnerLoo
 	return
 }
-// ResourceCompletions returns a list of keys that unambiguously reference resources in the bundle.
-func ResourceCompletions(b *bundle.Bundle) []string {
-	seen := make(map[string]bool)
-	comps := []string{}
+// ResourceCompletionMap returns a map of resource keys to their respective names.
+func ResourceCompletionMap(b *bundle.Bundle) map[string]string {
+	out := make(map[string]string)
 	keyOnly, keyWithType := ResourceKeys(b)
+	// Keep track of resources we have seen by their fully qualified key.
+	seen := make(map[string]bool)
 	// First add resources that can be identified by key alone.
 	for k, v := range keyOnly {
 		// Invariant: len(v) >= 1. See [ResourceKeys].
 		if len(v) == 1 {
 			seen[v[0].Key()] = true
-			comps = append(comps, k)
+			out[k] = v[0].Name()
 		}
 	}
@@ -54,8 +57,13 @@ func ResourceCompletions(b *bundle.Bundle) []string {
 		if ok {
 			continue
 		}
-		comps = append(comps, k)
+		out[k] = v[0].Name()
 	}
-	return comps
+	return out
+}
+// ResourceCompletions returns a list of keys that unambiguously reference resources in the bundle.
+func ResourceCompletions(b *bundle.Bundle) []string {
+	return maps.Keys(ResourceCompletionMap(b))
 }
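The two-pass logic above offers a bare key only when it identifies exactly one resource; anything ambiguous is reachable only through its type-qualified key. A simplified sketch with plain strings in place of the CLI's runner types (the map shapes and names here are illustrative):

```go
package main

import "fmt"

// completions mirrors the two-pass disambiguation: keys mapping to exactly
// one resource are offered bare; for the rest, only the type-qualified key
// is offered. Fully qualified keys already covered by a bare key are skipped.
func completions(keyOnly, keyWithType map[string][]string) map[string]bool {
	out := make(map[string]bool)
	seen := make(map[string]bool)
	// First pass: unambiguous bare keys.
	for k, v := range keyOnly {
		if len(v) == 1 {
			seen[v[0]] = true
			out[k] = true
		}
	}
	// Second pass: qualified keys for everything not yet covered.
	for k, v := range keyWithType {
		if seen[v[0]] {
			continue
		}
		out[k] = true
	}
	return out
}

func main() {
	keyOnly := map[string][]string{
		"etl":  {"jobs.etl"},
		"main": {"jobs.main", "pipelines.main"}, // ambiguous key
	}
	keyWithType := map[string][]string{
		"jobs.etl":       {"jobs.etl"},
		"jobs.main":      {"jobs.main"},
		"pipelines.main": {"pipelines.main"},
	}
	fmt.Println(len(completions(keyOnly, keyWithType))) // 3
}
```

Here "etl" is offered bare, while "main" is only reachable as "jobs.main" or "pipelines.main".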


@@ -60,7 +60,7 @@ func GetJobOutput(ctx context.Context, w *databricks.WorkspaceClient, runId int6
 		return nil, err
 	}
 	result := &JobOutput{
-		TaskOutputs: make([]TaskOutput, len(jobRun.Tasks)),
+		TaskOutputs: make([]TaskOutput, 0),
 	}
 	for _, task := range jobRun.Tasks {
 		jobRunOutput, err := w.Jobs.GetRunOutput(ctx, jobs.GetRunOutputRequest{
@@ -69,7 +69,11 @@ func GetJobOutput(ctx context.Context, w *databricks.WorkspaceClient, runId int6
 		if err != nil {
 			return nil, err
 		}
-		task := TaskOutput{TaskKey: task.TaskKey, Output: toRunOutput(jobRunOutput), EndTime: task.EndTime}
+		out := toRunOutput(jobRunOutput)
+		if out == nil {
+			continue
+		}
+		task := TaskOutput{TaskKey: task.TaskKey, Output: out, EndTime: task.EndTime}
 		result.TaskOutputs = append(result.TaskOutputs, task)
 	}
 	return result, nil
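This hunk fixes two related bugs: pre-sizing the slice with `len(jobRun.Tasks)` left zero-value entries for tasks that never got appended, and tasks with no output were still emitted. A self-contained sketch of the corrected append-only pattern (types and names simplified from the CLI's):

```go
package main

import "fmt"

type TaskOutput struct {
	TaskKey string
	Output  any
}

// collect starts with an empty slice and appends only tasks that produced
// output, instead of pre-sizing the slice with len(keys) zero values.
func collect(keys []string, outputs map[string]any) []TaskOutput {
	result := make([]TaskOutput, 0)
	for _, k := range keys {
		out, ok := outputs[k]
		if !ok || out == nil {
			continue // skip tasks with no output
		}
		result = append(result, TaskOutput{TaskKey: k, Output: out})
	}
	return result
}

func main() {
	got := collect([]string{"a", "b"}, map[string]any{"a": "logs"})
	fmt.Println(len(got), got[0].TaskKey) // 1 a
}
```

With `make([]TaskOutput, len(keys))` followed by `append`, the result would instead begin with `len(keys)` empty structs.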


@@ -136,6 +136,13 @@ type pipelineRunner struct {
 	pipeline *resources.Pipeline
 }
+func (r *pipelineRunner) Name() string {
+	if r.pipeline == nil || r.pipeline.PipelineSpec == nil {
+		return ""
+	}
+	return r.pipeline.PipelineSpec.Name
+}
 func (r *pipelineRunner) Run(ctx context.Context, opts *Options) (output.RunOutput, error) {
 	var pipelineID = r.pipeline.ID


@@ -21,6 +21,9 @@ type Runner interface {
 	// This is used for showing the user hints w.r.t. disambiguation.
 	Key() string
+	// Name returns the resource's name, if defined.
+	Name() string
 	// Run the underlying worklow.
 	Run(ctx context.Context, opts *Options) (output.RunOutput, error)
 }


@@ -1441,6 +1441,86 @@
	}
}
},
"model_serving_endpoints": {
"description": "List of Model Serving Endpoints",
"additionalproperties": {
"description": "",
"properties": {
"name": {
"description": "The name of the model serving endpoint. This field is required and must be unique across a workspace. An endpoint name can consist of alphanumeric characters, dashes, and underscores. NOTE: Changing this name will delete the existing endpoint and create a new endpoint with the update name."
},
"permissions": {
"description": "",
"items": {
"description": "",
"properties": {
"group_name": {
"description": ""
},
"level": {
"description": ""
},
"service_principal_name": {
"description": ""
},
"user_name": {
"description": ""
}
}
}
},
"config": {
"description": "The model serving endpoint configuration.",
"properties": {
"properties": {
"served_models": {
"description": "Each block represents a served model for the endpoint to serve. A model serving endpoint can have up to 10 served models.",
"items": {
"description": "",
"properties" : {
"name": {
"description": "The name of a served model. It must be unique across an endpoint. If not specified, this field will default to modelname-modelversion. A served model name can consist of alphanumeric characters, dashes, and underscores."
},
"model_name": {
"description": "The name of the model in Databricks Model Registry to be served."
},
"model_version": {
"description": "The version of the model in Databricks Model Registry to be served."
},
"workload_size": {
"description": "The workload size of the served model. The workload size corresponds to a range of provisioned concurrency that the compute will autoscale between. A single unit of provisioned concurrency can process one request at a time. Valid workload sizes are \"Small\" (4 - 4 provisioned concurrency), \"Medium\" (8 - 16 provisioned concurrency), and \"Large\" (16 - 64 provisioned concurrency)."
},
"scale_to_zero_enabled": {
"description": "Whether the compute resources for the served model should scale down to zero. If scale-to-zero is enabled, the lower bound of the provisioned concurrency for each workload size will be 0."
}
}
}
},
"traffic_config": {
"description": "A single block represents the traffic split configuration amongst the served models.",
"properties": {
"routes": {
"description": "Each block represents a route that defines traffic to each served model. Each served_models block needs to have a corresponding routes block.",
"items": {
"description": "",
"properties": {
"served_model_name": {
"description": "The name of the served model this route configures traffic for. This needs to match the name of a served_models block."
},
"traffic_percentage": {
"description": "The percentage of endpoint traffic to send to this route. It must be an integer between 0 and 100 inclusive."
}
}
}
}
}
}
}
}
}
}
}
},
"pipelines": { "pipelines": {
"description": "List of DLT pipelines", "description": "List of DLT pipelines",
"additionalproperties": { "additionalproperties": {

View File

@@ -210,6 +210,19 @@ func (reader *OpenapiReader) modelsDocs() (*Docs, error) {
return modelsDocs, nil
}
func (reader *OpenapiReader) modelServingEndpointsDocs() (*Docs, error) {
modelServingEndpointsSpecSchema, err := reader.readResolvedSchema(SchemaPathPrefix + "serving.CreateServingEndpoint")
if err != nil {
return nil, err
}
modelServingEndpointsDocs := schemaToDocs(modelServingEndpointsSpecSchema)
modelServingEndpointsAllDocs := &Docs{
Description: "List of Model Serving Endpoints",
AdditionalProperties: modelServingEndpointsDocs,
}
return modelServingEndpointsAllDocs, nil
}
func (reader *OpenapiReader) ResourcesDocs() (*Docs, error) {
jobsDocs, err := reader.jobsDocs()
if err != nil {
@@ -227,6 +240,10 @@ func (reader *OpenapiReader) ResourcesDocs() (*Docs, error) {
if err != nil {
return nil, err
}
modelServingEndpointsDocs, err := reader.modelServingEndpointsDocs()
if err != nil {
return nil, err
}
return &Docs{
Description: "Collection of Databricks resources to deploy.",
@@ -235,6 +252,7 @@ func (reader *OpenapiReader) ResourcesDocs() (*Docs, error) {
"pipelines": pipelinesDocs,
"experiments": experimentsDocs,
"models": modelsDocs,
"model_serving_endpoints": modelServingEndpointsDocs,
},
}, nil
}

View File

@@ -12,4 +12,4 @@ resources:
package_name: "my_test_code"
entry_point: "run"
libraries:
- - whl: dbfs://path/to/dist/mywheel.whl
+ - whl: dbfs:/path/to/dist/mywheel.whl

View File

@@ -0,0 +1,38 @@
resources:
model_serving_endpoints:
my_model_serving_endpoint:
name: "my-endpoint"
config:
served_models:
- model_name: "model-name"
model_version: "1"
workload_size: "Small"
scale_to_zero_enabled: true
traffic_config:
routes:
- served_model_name: "model-name-1"
traffic_percentage: 100
permissions:
- level: CAN_QUERY
group_name: users
targets:
development:
mode: development
resources:
model_serving_endpoints:
my_model_serving_endpoint:
name: "my-dev-endpoint"
staging:
resources:
model_serving_endpoints:
my_model_serving_endpoint:
name: "my-staging-endpoint"
production:
mode: production
resources:
model_serving_endpoints:
my_model_serving_endpoint:
name: "my-prod-endpoint"

View File

@@ -0,0 +1,48 @@
package config_tests
import (
"path/filepath"
"testing"
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/resources"
"github.com/stretchr/testify/assert"
)
func assertExpected(t *testing.T, p *resources.ModelServingEndpoint) {
assert.Equal(t, "model_serving_endpoint/databricks.yml", filepath.ToSlash(p.ConfigFilePath))
assert.Equal(t, "model-name", p.Config.ServedModels[0].ModelName)
assert.Equal(t, "1", p.Config.ServedModels[0].ModelVersion)
assert.Equal(t, "model-name-1", p.Config.TrafficConfig.Routes[0].ServedModelName)
assert.Equal(t, 100, p.Config.TrafficConfig.Routes[0].TrafficPercentage)
assert.Equal(t, "users", p.Permissions[0].GroupName)
assert.Equal(t, "CAN_QUERY", p.Permissions[0].Level)
}
func TestModelServingEndpointDevelopment(t *testing.T) {
b := loadTarget(t, "./model_serving_endpoint", "development")
assert.Len(t, b.Config.Resources.ModelServingEndpoints, 1)
assert.Equal(t, b.Config.Bundle.Mode, config.Development)
p := b.Config.Resources.ModelServingEndpoints["my_model_serving_endpoint"]
assert.Equal(t, "my-dev-endpoint", p.Name)
assertExpected(t, p)
}
func TestModelServingEndpointStaging(t *testing.T) {
b := loadTarget(t, "./model_serving_endpoint", "staging")
assert.Len(t, b.Config.Resources.ModelServingEndpoints, 1)
p := b.Config.Resources.ModelServingEndpoints["my_model_serving_endpoint"]
assert.Equal(t, "my-staging-endpoint", p.Name)
assertExpected(t, p)
}
func TestModelServingEndpointProduction(t *testing.T) {
b := loadTarget(t, "./model_serving_endpoint", "production")
assert.Len(t, b.Config.Resources.ModelServingEndpoints, 1)
p := b.Config.Resources.ModelServingEndpoints["my_model_serving_endpoint"]
assert.Equal(t, "my-prod-endpoint", p.Name)
assertExpected(t, p)
}
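The tests above rely on a target-specific resource block overriding only the fields it sets (here, `name`) while the shared definition's other fields are preserved. A toy sketch of that overlay behavior (the types and helper are hypothetical, not the bundle loader itself):

```go
package main

import "fmt"

// Endpoint holds the two fields exercised by the tests above;
// the real bundle configuration has many more.
type Endpoint struct {
	Name      string
	ModelName string
}

// applyTarget overlays non-empty fields from the target block onto the base,
// mimicking how a target's resources section overrides the shared one.
func applyTarget(base, target Endpoint) Endpoint {
	out := base
	if target.Name != "" {
		out.Name = target.Name
	}
	if target.ModelName != "" {
		out.ModelName = target.ModelName
	}
	return out
}

func main() {
	base := Endpoint{Name: "my-endpoint", ModelName: "model-name"}
	dev := applyTarget(base, Endpoint{Name: "my-dev-endpoint"})
	fmt.Println(dev.Name)      // my-dev-endpoint
	fmt.Println(dev.ModelName) // model-name
}
```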

View File

@@ -0,0 +1,17 @@
package config_tests
import (
"path/filepath"
"testing"
"github.com/databricks/cli/internal"
"github.com/stretchr/testify/require"
)
func TestSuggestTargetIfWrongPassed(t *testing.T) {
t.Setenv("BUNDLE_ROOT", filepath.Join("target_overrides", "workspace"))
_, _, err := internal.RequireErrorRun(t, "bundle", "validate", "-e", "incorrect")
require.ErrorContains(t, err, "Available targets:")
require.ErrorContains(t, err, "development")
require.ErrorContains(t, err, "staging")
}

View File

@@ -7,7 +7,7 @@ import (
func New() *cobra.Command {
cmd := &cobra.Command{
Use: "bundle",
-Short: "Databricks Asset Bundles",
+Short: "Databricks Asset Bundles\n\nOnline documentation: https://docs.databricks.com/en/dev-tools/bundles",
}
initVariableFlag(cmd)

View File

@@ -74,22 +74,21 @@ func newInitCommand() *cobra.Command {
return template.Materialize(ctx, configFile, templatePath, outputDir)
}
-// Download the template in a temporary directory
-tmpDir := os.TempDir()
-templateURL := templatePath
-repoDir := filepath.Join(tmpDir, repoName(templateURL))
-err := os.MkdirAll(repoDir, 0755)
+// Create a temporary directory with the name of the repository. The '*'
+// character is replaced by a random string in the generated temporary directory.
+repoDir, err := os.MkdirTemp("", repoName(templatePath)+"-*")
if err != nil {
return err
}
// TODO: Add automated test that the downloaded git repo is cleaned up.
-err = git.Clone(ctx, templateURL, "", repoDir)
+// Clone the repository in the temporary directory
+err = git.Clone(ctx, templatePath, "", repoDir)
if err != nil {
return err
}
-defer os.RemoveAll(templateDir)
+// Clean up downloaded repository once the template is materialized.
+defer os.RemoveAll(repoDir)
return template.Materialize(ctx, configFile, filepath.Join(repoDir, templateDir), outputDir)
}
return cmd
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/databricks/cli/bundle/phases"
"github.com/databricks/cli/bundle/run"
"github.com/databricks/cli/cmd/root"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/flags"
"github.com/spf13/cobra"
)
@@ -16,9 +17,9 @@
func newRunCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "run [flags] KEY",
-Short: "Run a workload (e.g. a job or a pipeline)",
-Args: cobra.ExactArgs(1),
+Short: "Run a resource (e.g. a job or a pipeline)",
+Args: cobra.MaximumNArgs(1),
PreRunE: ConfigureBundleWithVariables,
}
@@ -29,9 +30,10 @@
cmd.Flags().BoolVar(&noWait, "no-wait", false, "Don't wait for the run to complete.")
cmd.RunE = func(cmd *cobra.Command, args []string) error {
-b := bundle.Get(cmd.Context())
-err := bundle.Apply(cmd.Context(), b, bundle.Seq(
+ctx := cmd.Context()
+b := bundle.Get(ctx)
+err := bundle.Apply(ctx, b, bundle.Seq(
phases.Initialize(),
terraform.Interpolate(),
terraform.Write(),
@@ -42,13 +44,31 @@
return err
}
// If no arguments are specified, prompt the user to select something to run.
if len(args) == 0 && cmdio.IsInteractive(ctx) {
// Invert completions from KEY -> NAME, to NAME -> KEY.
inv := make(map[string]string)
for k, v := range run.ResourceCompletionMap(b) {
inv[v] = k
}
id, err := cmdio.Select(ctx, inv, "Resource to run")
if err != nil {
return err
}
args = append(args, id)
}
if len(args) != 1 {
return fmt.Errorf("expected a KEY of the resource to run")
}
runner, err := run.Find(b, args[0])
if err != nil {
return err
}
runOptions.NoWait = noWait
-output, err := runner.Run(cmd.Context(), &runOptions)
+output, err := runner.Run(ctx, &runOptions)
if err != nil {
return err
}

View File

@@ -18,12 +18,12 @@ type syncFlags struct {
}
func (f *syncFlags) syncOptionsFromBundle(cmd *cobra.Command, b *bundle.Bundle) (*sync.SyncOptions, error) {
-cacheDir, err := b.CacheDir()
+cacheDir, err := b.CacheDir(cmd.Context())
if err != nil {
return nil, fmt.Errorf("cannot get bundle cache directory: %w", err)
}
-includes, err := b.GetSyncIncludePatterns()
+includes, err := b.GetSyncIncludePatterns(cmd.Context())
if err != nil {
return nil, fmt.Errorf("cannot get list of sync includes: %w", err)
}

View File

@@ -1,6 +1,7 @@
package cmd
import (
"context"
"strings"
"github.com/databricks/cli/cmd/account"
@@ -21,8 +22,8 @@
permissionsGroup = "permissions"
)
-func New() *cobra.Command {
-cli := root.New()
+func New(ctx context.Context) *cobra.Command {
+cli := root.New(ctx)
// Add account subcommand.
cli.AddCommand(account.New())

View File

@@ -54,7 +54,7 @@ func TestDefaultConfigureNoInteractive(t *testing.T) {
})
os.Stdin = inp
-cmd := cmd.New()
+cmd := cmd.New(ctx)
cmd.SetArgs([]string{"configure", "--token", "--host", "https://host"})
err := cmd.ExecuteContext(ctx)
@@ -87,7 +87,7 @@ func TestConfigFileFromEnvNoInteractive(t *testing.T) {
t.Cleanup(func() { os.Stdin = oldStdin })
os.Stdin = inp
-cmd := cmd.New()
+cmd := cmd.New(ctx)
cmd.SetArgs([]string{"configure", "--token", "--host", "https://host"})
err := cmd.ExecuteContext(ctx)
@@ -116,7 +116,7 @@ func TestCustomProfileConfigureNoInteractive(t *testing.T) {
t.Cleanup(func() { os.Stdin = oldStdin })
os.Stdin = inp
-cmd := cmd.New()
+cmd := cmd.New(ctx)
cmd.SetArgs([]string{"configure", "--token", "--host", "https://host", "--profile", "CUSTOM"})
err := cmd.ExecuteContext(ctx)

View File

@@ -25,13 +25,57 @@ func initProfileFlag(cmd *cobra.Command) {
cmd.RegisterFlagCompletionFunc("profile", databrickscfg.ProfileCompletion)
}
func profileFlagValue(cmd *cobra.Command) (string, bool) {
profileFlag := cmd.Flag("profile")
if profileFlag == nil {
return "", false
}
value := profileFlag.Value.String()
return value, value != ""
}
// Helper function to create an account client or prompt once if the given configuration is not valid.
func accountClientOrPrompt(ctx context.Context, cfg *config.Config, allowPrompt bool) (*databricks.AccountClient, error) {
a, err := databricks.NewAccountClient((*databricks.Config)(cfg))
if err == nil {
err = a.Config.Authenticate(emptyHttpRequest(ctx))
}
prompt := false
if allowPrompt && err != nil && cmdio.IsInteractive(ctx) {
// Prompt to select a profile if the current configuration is not an account client.
prompt = prompt || errors.Is(err, databricks.ErrNotAccountClient)
// Prompt to select a profile if the current configuration doesn't resolve to a credential provider.
prompt = prompt || errors.Is(err, config.ErrCannotConfigureAuth)
}
if !prompt {
// If we are not prompting, we can return early.
return a, err
}
// Try picking a profile dynamically if the current configuration is not valid.
profile, err := askForAccountProfile(ctx)
if err != nil {
return nil, err
}
a, err = databricks.NewAccountClient(&databricks.Config{Profile: profile})
if err == nil {
err = a.Config.Authenticate(emptyHttpRequest(ctx))
if err != nil {
return nil, err
}
}
return a, nil
}
func MustAccountClient(cmd *cobra.Command, args []string) error {
cfg := &config.Config{}
-// command-line flag can specify the profile in use
-profileFlag := cmd.Flag("profile")
-if profileFlag != nil {
-cfg.Profile = profileFlag.Value.String()
+// The command-line profile flag takes precedence over DATABRICKS_CONFIG_PROFILE.
+profile, hasProfileFlag := profileFlagValue(cmd)
+if hasProfileFlag {
+cfg.Profile = profile
}
if cfg.Profile == "" {
@@ -48,16 +92,8 @@ func MustAccountClient(cmd *cobra.Command, args []string) error {
}
}
-TRY_AUTH: // or try picking a config profile dynamically
-a, err := databricks.NewAccountClient((*databricks.Config)(cfg))
-if cmdio.IsInteractive(cmd.Context()) && errors.Is(err, databricks.ErrNotAccountClient) {
-profile, err := askForAccountProfile()
-if err != nil {
-return err
-}
-cfg = &config.Config{Profile: profile}
-goto TRY_AUTH
-}
+allowPrompt := !hasProfileFlag
+a, err := accountClientOrPrompt(cmd.Context(), cfg, allowPrompt)
if err != nil {
return err
}
@@ -66,13 +102,48 @@ TRY_AUTH: // or try picking a config profile dynamically
return nil
}
// Helper function to create a workspace client or prompt once if the given configuration is not valid.
func workspaceClientOrPrompt(ctx context.Context, cfg *config.Config, allowPrompt bool) (*databricks.WorkspaceClient, error) {
w, err := databricks.NewWorkspaceClient((*databricks.Config)(cfg))
if err == nil {
err = w.Config.Authenticate(emptyHttpRequest(ctx))
}
prompt := false
if allowPrompt && err != nil && cmdio.IsInteractive(ctx) {
// Prompt to select a profile if the current configuration is not a workspace client.
prompt = prompt || errors.Is(err, databricks.ErrNotWorkspaceClient)
// Prompt to select a profile if the current configuration doesn't resolve to a credential provider.
prompt = prompt || errors.Is(err, config.ErrCannotConfigureAuth)
}
if !prompt {
// If we are not prompting, we can return early.
return w, err
}
// Try picking a profile dynamically if the current configuration is not valid.
profile, err := askForWorkspaceProfile(ctx)
if err != nil {
return nil, err
}
w, err = databricks.NewWorkspaceClient(&databricks.Config{Profile: profile})
if err == nil {
err = w.Config.Authenticate(emptyHttpRequest(ctx))
if err != nil {
return nil, err
}
}
return w, nil
}
func MustWorkspaceClient(cmd *cobra.Command, args []string) error {
cfg := &config.Config{}
-// command-line flag takes precedence over environment variable
-profileFlag := cmd.Flag("profile")
-if profileFlag != nil {
-cfg.Profile = profileFlag.Value.String()
+// The command-line profile flag takes precedence over DATABRICKS_CONFIG_PROFILE.
+profile, hasProfileFlag := profileFlagValue(cmd)
+if hasProfileFlag {
+cfg.Profile = profile
}
// try configuring a bundle
@@ -87,24 +158,13 @@ func MustWorkspaceClient(cmd *cobra.Command, args []string) error {
cfg = currentBundle.WorkspaceClient().Config
}
-TRY_AUTH: // or try picking a config profile dynamically
+allowPrompt := !hasProfileFlag
+w, err := workspaceClientOrPrompt(cmd.Context(), cfg, allowPrompt)
+if err != nil {
+return err
+}
ctx := cmd.Context()
-w, err := databricks.NewWorkspaceClient((*databricks.Config)(cfg))
-if err != nil {
-return err
-}
-err = w.Config.Authenticate(emptyHttpRequest(ctx))
-if cmdio.IsInteractive(ctx) && errors.Is(err, config.ErrCannotConfigureAuth) {
-profile, err := askForWorkspaceProfile()
-if err != nil {
-return err
-}
-cfg = &config.Config{Profile: profile}
-goto TRY_AUTH
-}
-if err != nil {
-return err
-}
ctx = context.WithValue(ctx, &workspaceClient, w)
cmd.SetContext(ctx)
return nil
@@ -121,7 +181,7 @@ func transformLoadError(path string, err error) error {
return err
}
-func askForWorkspaceProfile() (string, error) {
+func askForWorkspaceProfile(ctx context.Context) (string, error) {
path, err := databrickscfg.GetPath()
if err != nil {
return "", fmt.Errorf("cannot determine Databricks config file path: %w", err)
@@ -136,7 +196,7 @@ func askForWorkspaceProfile() (string, error) {
case 1:
return profiles[0].Name, nil
}
-i, _, err := (&promptui.Select{
+i, _, err := cmdio.RunSelect(ctx, &promptui.Select{
Label: fmt.Sprintf("Workspace profiles defined in %s", file),
Items: profiles,
Searcher: profiles.SearchCaseInsensitive,
@@ -147,16 +207,14 @@
Inactive: `{{.Name}}`,
Selected: `{{ "Using workspace profile" | faint }}: {{ .Name | bold }}`,
},
-Stdin: os.Stdin,
-Stdout: os.Stderr,
-}).Run()
+})
if err != nil {
return "", err
}
return profiles[i].Name, nil
}
-func askForAccountProfile() (string, error) {
+func askForAccountProfile(ctx context.Context) (string, error) {
path, err := databrickscfg.GetPath()
if err != nil {
return "", fmt.Errorf("cannot determine Databricks config file path: %w", err)
@@ -171,7 +229,7 @@ func askForAccountProfile() (string, error) {
case 1:
return profiles[0].Name, nil
}
-i, _, err := (&promptui.Select{
+i, _, err := cmdio.RunSelect(ctx, &promptui.Select{
Label: fmt.Sprintf("Account profiles defined in %s", file),
Items: profiles,
Searcher: profiles.SearchCaseInsensitive,
@@ -182,9 +240,7 @@
Inactive: `{{.Name}}`,
Selected: `{{ "Using account profile" | faint }}: {{ .Name | bold }}`,
},
-Stdin: os.Stdin,
-Stdout: os.Stderr,
-}).Run()
+})
if err != nil {
return "", err
}

View File

@@ -2,9 +2,16 @@ package root
import (
"context"
"os"
"path/filepath"
"testing"
"time"
"github.com/databricks/cli/internal/testutil"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/databricks-sdk-go/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestEmptyHttpRequest(t *testing.T) {
@@ -12,3 +19,165 @@ func TestEmptyHttpRequest(t *testing.T) {
req := emptyHttpRequest(ctx)
assert.Equal(t, req.Context(), ctx)
}
type promptFn func(ctx context.Context, cfg *config.Config, retry bool) (any, error)
var accountPromptFn = func(ctx context.Context, cfg *config.Config, retry bool) (any, error) {
return accountClientOrPrompt(ctx, cfg, retry)
}
var workspacePromptFn = func(ctx context.Context, cfg *config.Config, retry bool) (any, error) {
return workspaceClientOrPrompt(ctx, cfg, retry)
}
func expectPrompts(t *testing.T, fn promptFn, config *config.Config) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
// Channel to pass errors from the prompting function back to the test.
errch := make(chan error, 1)
ctx, io := cmdio.SetupTest(ctx)
go func() {
defer close(errch)
defer cancel()
_, err := fn(ctx, config, true)
errch <- err
}()
// Expect a prompt
line, _, err := io.Stderr.ReadLine()
if assert.NoError(t, err, "Expected to read a line from stderr") {
assert.Contains(t, string(line), "Search:")
} else {
// If there was an error reading from stderr, the prompting function must have terminated early.
assert.NoError(t, <-errch)
}
}
func expectReturns(t *testing.T, fn promptFn, config *config.Config) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
ctx, _ = cmdio.SetupTest(ctx)
client, err := fn(ctx, config, true)
require.NoError(t, err)
require.NotNil(t, client)
}
func TestAccountClientOrPrompt(t *testing.T) {
testutil.CleanupEnvironment(t)
dir := t.TempDir()
configFile := filepath.Join(dir, ".databrickscfg")
err := os.WriteFile(
configFile,
[]byte(`
[account-1111]
host = https://accounts.azuredatabricks.net/
account_id = 1111
token = foobar
[account-1112]
host = https://accounts.azuredatabricks.net/
account_id = 1112
token = foobar
`),
0755)
require.NoError(t, err)
t.Setenv("DATABRICKS_CONFIG_FILE", configFile)
t.Setenv("PATH", "/nothing")
t.Run("Prompt if nothing is specified", func(t *testing.T) {
expectPrompts(t, accountPromptFn, &config.Config{})
})
t.Run("Prompt if a workspace host is specified", func(t *testing.T) {
expectPrompts(t, accountPromptFn, &config.Config{
Host: "https://adb-1234567.89.azuredatabricks.net/",
AccountID: "1234",
Token: "foobar",
})
})
t.Run("Prompt if account ID is not specified", func(t *testing.T) {
expectPrompts(t, accountPromptFn, &config.Config{
Host: "https://accounts.azuredatabricks.net/",
Token: "foobar",
})
})
t.Run("Prompt if no credential provider can be configured", func(t *testing.T) {
expectPrompts(t, accountPromptFn, &config.Config{
Host: "https://accounts.azuredatabricks.net/",
AccountID: "1234",
})
})
t.Run("Returns if configuration is valid", func(t *testing.T) {
expectReturns(t, accountPromptFn, &config.Config{
Host: "https://accounts.azuredatabricks.net/",
AccountID: "1234",
Token: "foobar",
})
})
t.Run("Returns if a valid profile is specified", func(t *testing.T) {
expectReturns(t, accountPromptFn, &config.Config{
Profile: "account-1111",
})
})
}
func TestWorkspaceClientOrPrompt(t *testing.T) {
testutil.CleanupEnvironment(t)
dir := t.TempDir()
configFile := filepath.Join(dir, ".databrickscfg")
err := os.WriteFile(
configFile,
[]byte(`
[workspace-1111]
host = https://adb-1111.11.azuredatabricks.net/
token = foobar
[workspace-1112]
host = https://adb-1112.12.azuredatabricks.net/
token = foobar
`),
0755)
require.NoError(t, err)
t.Setenv("DATABRICKS_CONFIG_FILE", configFile)
t.Setenv("PATH", "/nothing")
t.Run("Prompt if nothing is specified", func(t *testing.T) {
expectPrompts(t, workspacePromptFn, &config.Config{})
})
t.Run("Prompt if an account host is specified", func(t *testing.T) {
expectPrompts(t, workspacePromptFn, &config.Config{
Host: "https://accounts.azuredatabricks.net/",
AccountID: "1234",
Token: "foobar",
})
})
t.Run("Prompt if no credential provider can be configured", func(t *testing.T) {
expectPrompts(t, workspacePromptFn, &config.Config{
Host: "https://adb-1111.11.azuredatabricks.net/",
})
})
t.Run("Returns if configuration is valid", func(t *testing.T) {
expectReturns(t, workspacePromptFn, &config.Config{
Host: "https://adb-1111.11.azuredatabricks.net/",
Token: "foobar",
})
})
t.Run("Returns if a valid profile is specified", func(t *testing.T) {
expectReturns(t, workspacePromptFn, &config.Config{
Profile: "workspace-1111",
})
})
}

View File

@@ -2,17 +2,15 @@ package root
import (
"context"
-"os"
"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config/mutator"
+"github.com/databricks/cli/bundle/env"
+envlib "github.com/databricks/cli/libs/env"
"github.com/spf13/cobra"
"golang.org/x/exp/maps"
)
-const envName = "DATABRICKS_BUNDLE_ENV"
-const targetName = "DATABRICKS_BUNDLE_TARGET"
// getTarget returns the name of the target to operate in.
func getTarget(cmd *cobra.Command) (value string) {
// The command line flag takes precedence.
@@ -33,13 +31,7 @@
}
// If it's not set, use the environment variable.
-target := os.Getenv(targetName)
-// If target env is not set with a new variable, try to check for old variable name
-// TODO: remove when environments section is not supported anymore
-if target == "" {
-target = os.Getenv(envName)
-}
+target, _ := env.Target(cmd.Context())
return target
}
@@ -54,7 +46,7 @@
}
// If it's not set, use the environment variable.
-return os.Getenv("DATABRICKS_CONFIG_PROFILE")
+return envlib.Get(cmd.Context(), "DATABRICKS_CONFIG_PROFILE")
}
// loadBundle loads the bundle configuration and applies default mutators. // loadBundle loads the bundle configuration and applies default mutators.

View File

@@ -9,6 +9,7 @@ import (
"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/internal/testutil"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
)
@@ -56,6 +57,8 @@ func setup(t *testing.T, cmd *cobra.Command, host string) *bundle.Bundle {
}
func TestBundleConfigureDefault(t *testing.T) {
testutil.CleanupEnvironment(t)
cmd := emptyCommand(t)
b := setup(t, cmd, "https://x.com")
assert.NotPanics(t, func() {
@@ -64,6 +67,8 @@ func TestBundleConfigureDefault(t *testing.T) {
}
func TestBundleConfigureWithMultipleMatches(t *testing.T) {
testutil.CleanupEnvironment(t)
cmd := emptyCommand(t)
b := setup(t, cmd, "https://a.com")
assert.Panics(t, func() {
@@ -72,6 +77,8 @@ func TestBundleConfigureWithMultipleMatches(t *testing.T) {
}
func TestBundleConfigureWithNonExistentProfileFlag(t *testing.T) {
testutil.CleanupEnvironment(t)
cmd := emptyCommand(t)
cmd.Flag("profile").Value.Set("NOEXIST")
@@ -82,6 +89,8 @@ func TestBundleConfigureWithNonExistentProfileFlag(t *testing.T) {
}
func TestBundleConfigureWithMismatchedProfile(t *testing.T) {
testutil.CleanupEnvironment(t)
cmd := emptyCommand(t)
cmd.Flag("profile").Value.Set("PROFILE-1")
@@ -92,6 +101,8 @@ func TestBundleConfigureWithMismatchedProfile(t *testing.T) {
}
func TestBundleConfigureWithCorrectProfile(t *testing.T) {
testutil.CleanupEnvironment(t)
cmd := emptyCommand(t)
cmd.Flag("profile").Value.Set("PROFILE-1")
@@ -102,10 +113,8 @@ func TestBundleConfigureWithCorrectProfile(t *testing.T) {
}
func TestBundleConfigureWithMismatchedProfileEnvVariable(t *testing.T) {
testutil.CleanupEnvironment(t)
t.Setenv("DATABRICKS_CONFIG_PROFILE", "PROFILE-1")
-t.Cleanup(func() {
-t.Setenv("DATABRICKS_CONFIG_PROFILE", "")
-})
cmd := emptyCommand(t)
b := setup(t, cmd, "https://x.com")
@@ -115,10 +124,8 @@ func TestBundleConfigureWithMismatchedProfileEnvVariable(t *testing.T) {
}
func TestBundleConfigureWithProfileFlagAndEnvVariable(t *testing.T) {
testutil.CleanupEnvironment(t)
t.Setenv("DATABRICKS_CONFIG_PROFILE", "NOEXIST")
-t.Cleanup(func() {
-t.Setenv("DATABRICKS_CONFIG_PROFILE", "")
-})
cmd := emptyCommand(t)
cmd.Flag("profile").Value.Set("PROFILE-1")
View File

@@ -1,9 +1,8 @@
package root
import (
-"os"
"github.com/databricks/cli/libs/cmdio"
+"github.com/databricks/cli/libs/env"
"github.com/databricks/cli/libs/flags"
"github.com/spf13/cobra"
)
@@ -21,7 +20,7 @@ func initOutputFlag(cmd *cobra.Command) *outputFlag {
// Configure defaults from environment, if applicable.
// If the provided value is invalid it is ignored.
-if v, ok := os.LookupEnv(envOutputFormat); ok {
+if v, ok := env.Lookup(cmd.Context(), envOutputFormat); ok {
f.output.Set(v)
}


@@ -5,9 +5,9 @@ import (
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"os"
"github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/env"
"github.com/databricks/cli/libs/flags" "github.com/databricks/cli/libs/flags"
"github.com/databricks/cli/libs/log" "github.com/databricks/cli/libs/log"
"github.com/fatih/color" "github.com/fatih/color"
@@ -126,13 +126,13 @@ func initLogFlags(cmd *cobra.Command) *logFlags {
// Configure defaults from environment, if applicable. // Configure defaults from environment, if applicable.
// If the provided value is invalid it is ignored. // If the provided value is invalid it is ignored.
if v, ok := os.LookupEnv(envLogFile); ok { if v, ok := env.Lookup(cmd.Context(), envLogFile); ok {
f.file.Set(v) f.file.Set(v)
} }
if v, ok := os.LookupEnv(envLogLevel); ok { if v, ok := env.Lookup(cmd.Context(), envLogLevel); ok {
f.level.Set(v) f.level.Set(v)
} }
if v, ok := os.LookupEnv(envLogFormat); ok { if v, ok := env.Lookup(cmd.Context(), envLogFormat); ok {
f.output.Set(v) f.output.Set(v)
} }


@@ -6,6 +6,7 @@ import (
"os" "os"
"github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/env"
"github.com/databricks/cli/libs/flags" "github.com/databricks/cli/libs/flags"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/term" "golang.org/x/term"
@@ -51,7 +52,7 @@ func initProgressLoggerFlag(cmd *cobra.Command, logFlags *logFlags) *progressLog
// Configure defaults from environment, if applicable. // Configure defaults from environment, if applicable.
// If the provided value is invalid it is ignored. // If the provided value is invalid it is ignored.
if v, ok := os.LookupEnv(envProgressFormat); ok { if v, ok := env.Lookup(cmd.Context(), envProgressFormat); ok {
f.Set(v) f.Set(v)
} }


@@ -14,7 +14,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
func New() *cobra.Command { func New(ctx context.Context) *cobra.Command {
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "databricks", Use: "databricks",
Short: "Databricks CLI", Short: "Databricks CLI",
@@ -30,6 +30,10 @@ func New() *cobra.Command {
SilenceErrors: true, SilenceErrors: true,
} }
// Pass the context along through the command during initialization.
// It will be overwritten when the command is executed.
cmd.SetContext(ctx)
// Initialize flags // Initialize flags
logFlags := initLogFlags(cmd) logFlags := initLogFlags(cmd)
progressLoggerFlag := initProgressLoggerFlag(cmd, logFlags) progressLoggerFlag := initProgressLoggerFlag(cmd, logFlags)


@@ -2,8 +2,8 @@ package root
import ( import (
"context" "context"
"os"
"github.com/databricks/cli/libs/env"
"github.com/databricks/databricks-sdk-go/useragent" "github.com/databricks/databricks-sdk-go/useragent"
) )
@@ -16,7 +16,7 @@ const upstreamKey = "upstream" const upstreamVersionKey = "upstream-version"
const upstreamVersionKey = "upstream-version" const upstreamVersionKey = "upstream-version"
func withUpstreamInUserAgent(ctx context.Context) context.Context { func withUpstreamInUserAgent(ctx context.Context) context.Context {
value := os.Getenv(upstreamEnvVar) value := env.Get(ctx, upstreamEnvVar)
if value == "" { if value == "" {
return ctx return ctx
} }
@@ -24,7 +24,7 @@ func withUpstreamInUserAgent(ctx context.Context) context.Context {
ctx = useragent.InContext(ctx, upstreamKey, value) ctx = useragent.InContext(ctx, upstreamKey, value)
// Include upstream version as well, if set. // Include upstream version as well, if set.
value = os.Getenv(upstreamVersionEnvVar) value = env.Get(ctx, upstreamVersionEnvVar)
if value == "" { if value == "" {
return ctx return ctx
} }


@@ -30,12 +30,12 @@ func (f *syncFlags) syncOptionsFromBundle(cmd *cobra.Command, args []string, b *
return nil, fmt.Errorf("SRC and DST are not configurable in the context of a bundle") return nil, fmt.Errorf("SRC and DST are not configurable in the context of a bundle")
} }
cacheDir, err := b.CacheDir() cacheDir, err := b.CacheDir(cmd.Context())
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot get bundle cache directory: %w", err) return nil, fmt.Errorf("cannot get bundle cache directory: %w", err)
} }
includes, err := b.GetSyncIncludePatterns() includes, err := b.GetSyncIncludePatterns(cmd.Context())
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot get list of sync includes: %w", err) return nil, fmt.Errorf("cannot get list of sync includes: %w", err)
} }


@@ -10,7 +10,7 @@ func New() *cobra.Command {
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "version", Use: "version",
Args: cobra.NoArgs, Args: cobra.NoArgs,
Short: "Retrieve information about the current version of this CLI",
Annotations: map[string]string{ Annotations: map[string]string{
"template": "Databricks CLI v{{.Version}}\n", "template": "Databricks CLI v{{.Version}}\n",
}, },

go.mod

@@ -4,7 +4,7 @@ go 1.21
require ( require (
github.com/briandowns/spinner v1.23.0 // Apache 2.0 github.com/briandowns/spinner v1.23.0 // Apache 2.0
github.com/databricks/databricks-sdk-go v0.19.0 // Apache 2.0 github.com/databricks/databricks-sdk-go v0.19.1 // Apache 2.0
github.com/fatih/color v1.15.0 // MIT github.com/fatih/color v1.15.0 // MIT
github.com/ghodss/yaml v1.0.0 // MIT + NOTICE github.com/ghodss/yaml v1.0.0 // MIT + NOTICE
github.com/google/uuid v1.3.0 // BSD-3-Clause github.com/google/uuid v1.3.0 // BSD-3-Clause

go.sum

@@ -36,8 +36,8 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH
github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/databricks/databricks-sdk-go v0.19.0 h1:Xh5A90/+8ehW7fTqoQbQK5xZu7a/akv3Xwv8UdWB4GU= github.com/databricks/databricks-sdk-go v0.19.1 h1:hP7xZb+Hd8n0grnEcf2FOMn6lWox7vp5KAan3D2hnzM=
github.com/databricks/databricks-sdk-go v0.19.0/go.mod h1:Bt/3i3ry/rQdE6Y+psvkAENlp+LzJHaQK5PsLIstQb4= github.com/databricks/databricks-sdk-go v0.19.1/go.mod h1:Bt/3i3ry/rQdE6Y+psvkAENlp+LzJHaQK5PsLIstQb4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=


@@ -0,0 +1,21 @@
{
"properties": {
"project_name": {
"type": "string",
"default": "my_test_code",
"description": "Unique name for this project"
},
"spark_version": {
"type": "string",
"description": "Spark version used for job cluster"
},
"node_type_id": {
"type": "string",
"description": "Node type id for job cluster"
},
"unique_id": {
"type": "string",
"description": "Unique ID for job name"
}
}
}


@@ -0,0 +1,24 @@
bundle:
name: wheel-task
workspace:
root_path: "~/.bundle/{{.unique_id}}"
resources:
jobs:
some_other_job:
name: "[${bundle.target}] Test Wheel Job {{.unique_id}}"
tasks:
- task_key: TestTask
new_cluster:
num_workers: 1
spark_version: "{{.spark_version}}"
node_type_id: "{{.node_type_id}}"
python_wheel_task:
package_name: my_test_code
entry_point: run
parameters:
- "one"
- "two"
libraries:
- whl: ./dist/*.whl


@@ -0,0 +1,15 @@
from setuptools import setup, find_packages
import {{.project_name}}
setup(
name="{{.project_name}}",
version={{.project_name}}.__version__,
author={{.project_name}}.__author__,
url="https://databricks.com",
author_email="john.doe@databricks.com",
description="my example wheel",
packages=find_packages(include=["{{.project_name}}"]),
entry_points={"group1": "run={{.project_name}}.__main__:main"},
install_requires=["setuptools"],
)


@@ -0,0 +1,2 @@
__version__ = "0.0.1"
__author__ = "Databricks"


@@ -0,0 +1,16 @@
"""
The entry point of the Python Wheel
"""
import sys
def main():
# This method will print the provided arguments
print("Hello from my func")
print("Got arguments:")
print(sys.argv)
if __name__ == "__main__":
main()


@@ -0,0 +1,70 @@
package bundle
import (
"context"
"encoding/json"
"os"
"path/filepath"
"strings"
"testing"
"github.com/databricks/cli/cmd/root"
"github.com/databricks/cli/internal"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/flags"
"github.com/databricks/cli/libs/template"
)
func initTestTemplate(t *testing.T, templateName string, config map[string]any) (string, error) {
templateRoot := filepath.Join("bundles", templateName)
bundleRoot := t.TempDir()
configFilePath, err := writeConfigFile(t, config)
if err != nil {
return "", err
}
ctx := root.SetWorkspaceClient(context.Background(), nil)
cmd := cmdio.NewIO(flags.OutputJSON, strings.NewReader(""), os.Stdout, os.Stderr, "bundles")
ctx = cmdio.InContext(ctx, cmd)
err = template.Materialize(ctx, configFilePath, templateRoot, bundleRoot)
return bundleRoot, err
}
func writeConfigFile(t *testing.T, config map[string]any) (string, error) {
bytes, err := json.Marshal(config)
if err != nil {
return "", err
}
dir := t.TempDir()
filepath := filepath.Join(dir, "config.json")
t.Log("Configuration for template: ", string(bytes))
err = os.WriteFile(filepath, bytes, 0644)
return filepath, err
}
func deployBundle(t *testing.T, path string) error {
t.Setenv("BUNDLE_ROOT", path)
c := internal.NewCobraTestRunner(t, "bundle", "deploy", "--force-lock")
_, _, err := c.Run()
return err
}
func runResource(t *testing.T, path string, key string) (string, error) {
ctx := context.Background()
ctx = cmdio.NewContext(ctx, cmdio.Default())
c := internal.NewCobraTestRunnerWithContext(t, ctx, "bundle", "run", key)
stdout, _, err := c.Run()
return stdout.String(), err
}
func destroyBundle(t *testing.T, path string) error {
t.Setenv("BUNDLE_ROOT", path)
c := internal.NewCobraTestRunner(t, "bundle", "destroy", "--auto-approve")
_, _, err := c.Run()
return err
}


@@ -0,0 +1,43 @@
package bundle
import (
"testing"
"github.com/databricks/cli/internal"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)
func TestAccPythonWheelTaskDeployAndRun(t *testing.T) {
env := internal.GetEnvOrSkipTest(t, "CLOUD_ENV")
t.Log(env)
var nodeTypeId string
if env == "gcp" {
nodeTypeId = "n1-standard-4"
} else if env == "aws" {
nodeTypeId = "i3.xlarge"
} else {
nodeTypeId = "Standard_DS4_v2"
}
bundleRoot, err := initTestTemplate(t, "python_wheel_task", map[string]any{
"node_type_id": nodeTypeId,
"unique_id": uuid.New().String(),
"spark_version": "13.2.x-snapshot-scala2.12",
})
require.NoError(t, err)
err = deployBundle(t, bundleRoot)
require.NoError(t, err)
t.Cleanup(func() {
destroyBundle(t, bundleRoot)
})
out, err := runResource(t, bundleRoot, "some_other_job")
require.NoError(t, err)
require.Contains(t, out, "Hello from my func")
require.Contains(t, out, "Got arguments:")
require.Contains(t, out, "['python', 'one', 'two']")
}


@@ -58,6 +58,8 @@ type cobraTestRunner struct {
stdout bytes.Buffer stdout bytes.Buffer
stderr bytes.Buffer stderr bytes.Buffer
ctx context.Context
// Line-by-line output. // Line-by-line output.
// Background goroutines populate these channels by reading from stdout/stderr pipes. // Background goroutines populate these channels by reading from stdout/stderr pipes.
stdoutLines <-chan string stdoutLines <-chan string
@@ -116,7 +118,7 @@ func (t *cobraTestRunner) RunBackground() {
var stdoutW, stderrW io.WriteCloser var stdoutW, stderrW io.WriteCloser
stdoutR, stdoutW = io.Pipe() stdoutR, stdoutW = io.Pipe()
stderrR, stderrW = io.Pipe() stderrR, stderrW = io.Pipe()
root := cmd.New() root := cmd.New(context.Background())
root.SetOut(stdoutW) root.SetOut(stdoutW)
root.SetErr(stderrW) root.SetErr(stderrW)
root.SetArgs(t.args) root.SetArgs(t.args)
@@ -128,7 +130,7 @@ func (t *cobraTestRunner) RunBackground() {
t.registerFlagCleanup(root) t.registerFlagCleanup(root)
errch := make(chan error) errch := make(chan error)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(t.ctx)
// Tee stdout/stderr to buffers. // Tee stdout/stderr to buffers.
stdoutR = io.TeeReader(stdoutR, &t.stdout) stdoutR = io.TeeReader(stdoutR, &t.stdout)
@@ -234,6 +236,15 @@ func (c *cobraTestRunner) Eventually(condition func() bool, waitFor time.Duratio
func NewCobraTestRunner(t *testing.T, args ...string) *cobraTestRunner { func NewCobraTestRunner(t *testing.T, args ...string) *cobraTestRunner {
return &cobraTestRunner{ return &cobraTestRunner{
T: t, T: t,
ctx: context.Background(),
args: args,
}
}
func NewCobraTestRunnerWithContext(t *testing.T, ctx context.Context, args ...string) *cobraTestRunner {
return &cobraTestRunner{
T: t,
ctx: ctx,
args: args, args: args,
} }
} }

internal/testutil/env.go Normal file

@@ -0,0 +1,37 @@
package testutil
import (
"os"
"runtime"
"strings"
"testing"
)
// CleanupEnvironment sets up a pristine environment containing only $PATH and $HOME.
// The original environment is restored upon test completion.
// Note: use of this function is incompatible with parallel execution.
func CleanupEnvironment(t *testing.T) {
// Restore environment when test finishes.
environ := os.Environ()
t.Cleanup(func() {
// Restore original environment.
for _, kv := range environ {
kvs := strings.SplitN(kv, "=", 2)
os.Setenv(kvs[0], kvs[1])
}
})
path := os.Getenv("PATH")
pwd := os.Getenv("PWD")
os.Clearenv()
// We use t.Setenv instead of os.Setenv because the former actively
// prevents a test from being run with t.Parallel. Modifying the environment
// within a test is not compatible with running tests in parallel
// because of isolation; the environment is scoped to the process.
t.Setenv("PATH", path)
t.Setenv("HOME", pwd)
if runtime.GOOS == "windows" {
t.Setenv("USERPROFILE", pwd)
}
}


@@ -205,6 +205,13 @@ func Prompt(ctx context.Context) *promptui.Prompt {
} }
} }
func RunSelect(ctx context.Context, prompt *promptui.Select) (int, string, error) {
c := fromContext(ctx)
prompt.Stdin = io.NopCloser(c.in)
prompt.Stdout = nopWriteCloser{c.err}
return prompt.Run()
}
func (c *cmdIO) simplePrompt(label string) *promptui.Prompt { func (c *cmdIO) simplePrompt(label string) *promptui.Prompt {
return &promptui.Prompt{ return &promptui.Prompt{
Label: label, Label: label,


@@ -10,6 +10,7 @@ import (
"strings" "strings"
"github.com/databricks/cli/libs/flags" "github.com/databricks/cli/libs/flags"
"github.com/manifoldco/promptui"
) )
// This is the interface for all io interactions with a user // This is the interface for all io interactions with a user
@@ -104,6 +105,36 @@ func AskYesOrNo(ctx context.Context, question string) (bool, error) {
return false, nil return false, nil
} }
func AskSelect(ctx context.Context, question string, choices []string) (string, error) {
logger, ok := FromContext(ctx)
if !ok {
logger = Default()
}
return logger.AskSelect(question, choices)
}
func (l *Logger) AskSelect(question string, choices []string) (string, error) {
if l.Mode == flags.ModeJson {
return "", fmt.Errorf("question prompts are not supported in json mode")
}
prompt := promptui.Select{
Label: question,
Items: choices,
HideHelp: true,
Templates: &promptui.SelectTemplates{
Label: "{{.}}: ",
Selected: fmt.Sprintf("%s: {{.}}", question),
},
}
_, ans, err := prompt.Run()
if err != nil {
return "", err
}
return ans, nil
}
func (l *Logger) Ask(question string, defaultVal string) (string, error) { func (l *Logger) Ask(question string, defaultVal string) (string, error) {
if l.Mode == flags.ModeJson { if l.Mode == flags.ModeJson {
return "", fmt.Errorf("question prompts are not supported in json mode") return "", fmt.Errorf("question prompts are not supported in json mode")


@@ -1,6 +1,7 @@ package cmdio
package cmdio package cmdio
import ( import (
"context"
"testing" "testing"
"github.com/databricks/cli/libs/flags" "github.com/databricks/cli/libs/flags"
@@ -12,3 +13,11 @@ func TestAskFailedInJsonMode(t *testing.T) {
_, err := l.Ask("What is your spirit animal?", "") _, err := l.Ask("What is your spirit animal?", "")
assert.ErrorContains(t, err, "question prompts are not supported in json mode") assert.ErrorContains(t, err, "question prompts are not supported in json mode")
} }
func TestAskChoiceFailsInJsonMode(t *testing.T) {
l := NewLogger(flags.ModeJson)
ctx := NewContext(context.Background(), l)
_, err := AskSelect(ctx, "what is a question?", []string{"b", "c", "a"})
assert.EqualError(t, err, "question prompts are not supported in json mode")
}

libs/cmdio/testing.go Normal file

@@ -0,0 +1,46 @@
package cmdio
import (
"bufio"
"context"
"io"
)
type Test struct {
Done context.CancelFunc
Stdin *bufio.Writer
Stdout *bufio.Reader
Stderr *bufio.Reader
}
func SetupTest(ctx context.Context) (context.Context, *Test) {
rin, win := io.Pipe()
rout, wout := io.Pipe()
rerr, werr := io.Pipe()
cmdio := &cmdIO{
interactive: true,
in: rin,
out: wout,
err: werr,
}
ctx, cancel := context.WithCancel(ctx)
ctx = InContext(ctx, cmdio)
// Wait for context to be done, so we can drain stdin and close the pipes.
go func() {
<-ctx.Done()
rin.Close()
wout.Close()
werr.Close()
}()
return ctx, &Test{
Done: cancel,
Stdin: bufio.NewWriter(win),
Stdout: bufio.NewReader(rout),
Stderr: bufio.NewReader(rerr),
}
}

libs/env/context.go vendored Normal file

@@ -0,0 +1,63 @@
package env
import (
"context"
"os"
)
var envContextKey int
func copyMap(m map[string]string) map[string]string {
out := make(map[string]string, len(m))
for k, v := range m {
out[k] = v
}
return out
}
func getMap(ctx context.Context) map[string]string {
if ctx == nil {
return nil
}
m, ok := ctx.Value(&envContextKey).(map[string]string)
if !ok {
return nil
}
return m
}
func setMap(ctx context.Context, m map[string]string) context.Context {
return context.WithValue(ctx, &envContextKey, m)
}
// Lookup key in the context or the environment.
// Context has precedence.
func Lookup(ctx context.Context, key string) (string, bool) {
m := getMap(ctx)
// Return if the key is set in the context.
v, ok := m[key]
if ok {
return v, true
}
// Fall back to the environment.
return os.LookupEnv(key)
}
// Get key from the context or the environment.
// Context has precedence.
func Get(ctx context.Context, key string) string {
v, _ := Lookup(ctx, key)
return v
}
// Set key on the context.
//
// Note: this does NOT mutate the process's actual environment variables.
// It is only visible to other code that uses this package.
func Set(ctx context.Context, key, value string) context.Context {
m := copyMap(getMap(ctx))
m[key] = value
return setMap(ctx, m)
}

libs/env/context_test.go vendored Normal file

@@ -0,0 +1,41 @@
package env
import (
"context"
"testing"
"github.com/databricks/cli/internal/testutil"
"github.com/stretchr/testify/assert"
)
func TestContext(t *testing.T) {
testutil.CleanupEnvironment(t)
t.Setenv("FOO", "bar")
ctx0 := context.Background()
// Get
assert.Equal(t, "bar", Get(ctx0, "FOO"))
assert.Equal(t, "", Get(ctx0, "dontexist"))
// Lookup
v, ok := Lookup(ctx0, "FOO")
assert.True(t, ok)
assert.Equal(t, "bar", v)
v, ok = Lookup(ctx0, "dontexist")
assert.False(t, ok)
assert.Equal(t, "", v)
// Set and get new context.
// Verify that the previous context remains unchanged.
ctx1 := Set(ctx0, "FOO", "baz")
assert.Equal(t, "baz", Get(ctx1, "FOO"))
assert.Equal(t, "bar", Get(ctx0, "FOO"))
// Set and get new context.
// Verify that the previous contexts remains unchanged.
ctx2 := Set(ctx1, "FOO", "qux")
assert.Equal(t, "qux", Get(ctx2, "FOO"))
assert.Equal(t, "baz", Get(ctx1, "FOO"))
assert.Equal(t, "bar", Get(ctx0, "FOO"))
}

libs/env/pkg.go vendored Normal file

@@ -0,0 +1,7 @@
package env
// The env package provides functions for working with environment variables
// and allowing for overrides via the context.Context. This is useful for
// testing where tainting a process's environment is at odds with parallelism.
// Use of a context.Context to store variable overrides means tests can be
// parallelized without worrying about environment variable interference.

libs/jsonschema/instance.go Normal file

@@ -0,0 +1,113 @@
package jsonschema
import (
"encoding/json"
"fmt"
"os"
"slices"
)
// Load a JSON document and validate it against the JSON schema. Instance here
// refers to a JSON document. See: https://json-schema.org/draft/2020-12/json-schema-core.html#name-instance
func (s *Schema) LoadInstance(path string) (map[string]any, error) {
instance := make(map[string]any)
b, err := os.ReadFile(path)
if err != nil {
return nil, err
}
err = json.Unmarshal(b, &instance)
if err != nil {
return nil, err
}
// The default JSON unmarshaler parses untyped number values as float64.
// We convert integer properties from float64 to int64 here.
for name, v := range instance {
propertySchema, ok := s.Properties[name]
if !ok {
continue
}
if propertySchema.Type != IntegerType {
continue
}
integerValue, err := toInteger(v)
if err != nil {
return nil, fmt.Errorf("failed to parse property %s: %w", name, err)
}
instance[name] = integerValue
}
return instance, s.ValidateInstance(instance)
}
func (s *Schema) ValidateInstance(instance map[string]any) error {
for _, fn := range []func(map[string]any) error{
s.validateAdditionalProperties,
s.validateEnum,
s.validateRequired,
s.validateTypes,
} {
err := fn(instance)
if err != nil {
return err
}
}
return nil
}
// If additionalProperties is set to false, this function validates that the
// instance only contains properties defined in the schema.
func (s *Schema) validateAdditionalProperties(instance map[string]any) error {
// Note: AdditionalProperties has the type any.
if s.AdditionalProperties != false {
return nil
}
for k := range instance {
_, ok := s.Properties[k]
if !ok {
return fmt.Errorf("property %s is not defined in the schema", k)
}
}
return nil
}
// This function validates that all required properties in the schema have values
// in the instance.
func (s *Schema) validateRequired(instance map[string]any) error {
for _, name := range s.Required {
if _, ok := instance[name]; !ok {
return fmt.Errorf("no value provided for required property %s", name)
}
}
return nil
}
// Validates that the types of all instance property values match the types defined in the schema.
func (s *Schema) validateTypes(instance map[string]any) error {
for k, v := range instance {
fieldInfo, ok := s.Properties[k]
if !ok {
continue
}
err := validateType(v, fieldInfo.Type)
if err != nil {
return fmt.Errorf("incorrect type for property %s: %w", k, err)
}
}
return nil
}
func (s *Schema) validateEnum(instance map[string]any) error {
for k, v := range instance {
fieldInfo, ok := s.Properties[k]
if !ok {
continue
}
if fieldInfo.Enum == nil {
continue
}
if !slices.Contains(fieldInfo.Enum, v) {
return fmt.Errorf("expected value of property %s to be one of %v. Found: %v", k, fieldInfo.Enum, v)
}
}
return nil
}


@@ -0,0 +1,155 @@
package jsonschema
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestValidateInstanceAdditionalPropertiesPermitted(t *testing.T) {
instance := map[string]any{
"int_val": 1,
"float_val": 1.0,
"bool_val": false,
"an_additional_property": "abc",
}
schema, err := Load("./testdata/instance-validate/test-schema.json")
require.NoError(t, err)
err = schema.validateAdditionalProperties(instance)
assert.NoError(t, err)
err = schema.ValidateInstance(instance)
assert.NoError(t, err)
}
func TestValidateInstanceAdditionalPropertiesForbidden(t *testing.T) {
instance := map[string]any{
"int_val": 1,
"float_val": 1.0,
"bool_val": false,
"an_additional_property": "abc",
}
schema, err := Load("./testdata/instance-validate/test-schema-no-additional-properties.json")
require.NoError(t, err)
err = schema.validateAdditionalProperties(instance)
assert.EqualError(t, err, "property an_additional_property is not defined in the schema")
err = schema.ValidateInstance(instance)
assert.EqualError(t, err, "property an_additional_property is not defined in the schema")
instanceWOAdditionalProperties := map[string]any{
"int_val": 1,
"float_val": 1.0,
"bool_val": false,
}
err = schema.validateAdditionalProperties(instanceWOAdditionalProperties)
assert.NoError(t, err)
err = schema.ValidateInstance(instanceWOAdditionalProperties)
assert.NoError(t, err)
}
func TestValidateInstanceTypes(t *testing.T) {
schema, err := Load("./testdata/instance-validate/test-schema.json")
require.NoError(t, err)
validInstance := map[string]any{
"int_val": 1,
"float_val": 1.0,
"bool_val": false,
}
err = schema.validateTypes(validInstance)
assert.NoError(t, err)
err = schema.ValidateInstance(validInstance)
assert.NoError(t, err)
invalidInstance := map[string]any{
"int_val": "abc",
"float_val": 1.0,
"bool_val": false,
}
err = schema.validateTypes(invalidInstance)
assert.EqualError(t, err, "incorrect type for property int_val: expected type integer, but value is \"abc\"")
err = schema.ValidateInstance(invalidInstance)
assert.EqualError(t, err, "incorrect type for property int_val: expected type integer, but value is \"abc\"")
}
func TestValidateInstanceRequired(t *testing.T) {
schema, err := Load("./testdata/instance-validate/test-schema-some-fields-required.json")
require.NoError(t, err)
validInstance := map[string]any{
"int_val": 1,
"float_val": 1.0,
"bool_val": false,
}
err = schema.validateRequired(validInstance)
assert.NoError(t, err)
err = schema.ValidateInstance(validInstance)
assert.NoError(t, err)
invalidInstance := map[string]any{
"string_val": "abc",
"float_val": 1.0,
"bool_val": false,
}
err = schema.validateRequired(invalidInstance)
assert.EqualError(t, err, "no value provided for required property int_val")
err = schema.ValidateInstance(invalidInstance)
assert.EqualError(t, err, "no value provided for required property int_val")
}
func TestLoadInstance(t *testing.T) {
schema, err := Load("./testdata/instance-validate/test-schema.json")
require.NoError(t, err)
// Expect the instance to be loaded successfully.
instance, err := schema.LoadInstance("./testdata/instance-load/valid-instance.json")
assert.NoError(t, err)
assert.Equal(t, map[string]any{
"bool_val": false,
"int_val": int64(1),
"string_val": "abc",
"float_val": 2.0,
}, instance)
// Expect instance validation against the schema to fail.
_, err = schema.LoadInstance("./testdata/instance-load/invalid-type-instance.json")
assert.EqualError(t, err, "incorrect type for property string_val: expected type string, but value is 123")
}
func TestValidateInstanceEnum(t *testing.T) {
schema, err := Load("./testdata/instance-validate/test-schema-enum.json")
require.NoError(t, err)
validInstance := map[string]any{
"foo": "b",
"bar": int64(6),
}
assert.NoError(t, schema.validateEnum(validInstance))
assert.NoError(t, schema.ValidateInstance(validInstance))
invalidStringInstance := map[string]any{
"foo": "d",
"bar": int64(2),
}
assert.EqualError(t, schema.validateEnum(invalidStringInstance), "expected value of property foo to be one of [a b c]. Found: d")
assert.EqualError(t, schema.ValidateInstance(invalidStringInstance), "expected value of property foo to be one of [a b c]. Found: d")
invalidIntInstance := map[string]any{
"foo": "a",
"bar": int64(1),
}
assert.EqualError(t, schema.validateEnum(invalidIntInstance), "expected value of property bar to be one of [2 4 6]. Found: 1")
assert.EqualError(t, schema.ValidateInstance(invalidIntInstance), "expected value of property bar to be one of [2 4 6]. Found: 1")
}

Some files were not shown because too many files have changed in this diff.