diff --git a/bundle/config/mutator/configure_wsfs.go b/bundle/config/mutator/configure_wsfs.go index 1d1bec58..296536d1 100644 --- a/bundle/config/mutator/configure_wsfs.go +++ b/bundle/config/mutator/configure_wsfs.go @@ -6,13 +6,11 @@ import ( "github.com/databricks/cli/bundle" "github.com/databricks/cli/libs/diag" - "github.com/databricks/cli/libs/env" "github.com/databricks/cli/libs/filer" + "github.com/databricks/cli/libs/runtime" "github.com/databricks/cli/libs/vfs" ) -const envDatabricksRuntimeVersion = "DATABRICKS_RUNTIME_VERSION" - type configureWSFS struct{} func ConfigureWSFS() bundle.Mutator { @@ -32,7 +30,7 @@ func (m *configureWSFS) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagno } // The executable must be running on DBR. - if _, ok := env.Lookup(ctx, envDatabricksRuntimeVersion); !ok { + if !runtime.RunsOnDatabricks(ctx) { return nil } diff --git a/libs/notebook/detect.go b/libs/notebook/detect.go index 582a8847..0d49b3da 100644 --- a/libs/notebook/detect.go +++ b/libs/notebook/detect.go @@ -139,3 +139,37 @@ func Detect(name string) (notebook bool, language workspace.Language, err error) b := filepath.Base(name) return DetectWithFS(os.DirFS(d), b) } + +type inMemoryFile struct { + buffer bytes.Buffer +} + +type inMemoryFS struct { + content []byte +} + +func (f *inMemoryFile) Close() error { + return nil +} + +func (f *inMemoryFile) Stat() (fs.FileInfo, error) { + return nil, nil +} + +func (f *inMemoryFile) Read(b []byte) (n int, err error) { + return f.buffer.Read(b) +} + +func (fs inMemoryFS) Open(name string) (fs.File, error) { + return &inMemoryFile{ + buffer: *bytes.NewBuffer(fs.content), + }, nil +} + +func DetectWithContent(name string, content []byte) (notebook bool, language workspace.Language, err error) { + fs := inMemoryFS{ + content: content, + } + + return DetectWithFS(fs, name) +} diff --git a/libs/notebook/detect_test.go b/libs/notebook/detect_test.go index ad89d6dd..a5892fe7 100644 --- a/libs/notebook/detect_test.go +++ b/libs/notebook/detect_test.go @@ -117,3 +117,22 @@ func TestDetectWithObjectInfo(t *testing.T) { assert.True(t, nb) assert.Equal(t, workspace.LanguagePython, lang) } + +func TestInMemoryFiles(t *testing.T) { + isNotebook, language, err := DetectWithContent("hello.py", []byte("# Databricks notebook source\n print('hello')")) + assert.True(t, isNotebook) + assert.Equal(t, workspace.LanguagePython, language) + require.NoError(t, err) + + isNotebook, language, err = DetectWithContent("hello.py", []byte("print('hello')")) + assert.False(t, isNotebook) + assert.Equal(t, workspace.Language(""), language) + require.NoError(t, err) + + fileContent, err := os.ReadFile("./testdata/py_ipynb.ipynb") + require.NoError(t, err) + isNotebook, language, err = DetectWithContent("py_ipynb.ipynb", fileContent) + assert.True(t, isNotebook) + assert.Equal(t, workspace.LanguagePython, language) + require.NoError(t, err) +} diff --git a/libs/runtime/detect.go b/libs/runtime/detect.go new file mode 100644 index 00000000..765eea8d --- /dev/null +++ b/libs/runtime/detect.go @@ -0,0 +1,14 @@ +package runtime + +import ( + "context" + + "github.com/databricks/cli/libs/env" +) + +const envDatabricksRuntimeVersion = "DATABRICKS_RUNTIME_VERSION" + +func RunsOnDatabricks(ctx context.Context) bool { + value, ok := env.Lookup(ctx, envDatabricksRuntimeVersion) + return value != "" && ok +} diff --git a/libs/runtime/detect_test.go b/libs/runtime/detect_test.go new file mode 100644 index 00000000..ccc2254e --- /dev/null +++ b/libs/runtime/detect_test.go @@ -0,0 +1,18 @@ +package runtime + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRunsOnDatabricks(t *testing.T) { + ctx := context.Background() + + t.Setenv("DATABRICKS_RUNTIME_VERSION", "") + assert.False(t, RunsOnDatabricks(ctx)) + + t.Setenv("DATABRICKS_RUNTIME_VERSION", "14.3") + assert.True(t, RunsOnDatabricks(ctx)) +} diff --git a/libs/template/file.go b/libs/template/file.go index aafb1acf..826ded37 100644 --- a/libs/template/file.go +++ b/libs/template/file.go @@ -2,12 +2,19 @@ package template import ( "context" + "encoding/base64" "io" "io/fs" "os" "path/filepath" + "strings" + "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/libs/filer" + "github.com/databricks/cli/libs/log" + "github.com/databricks/cli/libs/notebook" + "github.com/databricks/cli/libs/runtime" + "github.com/databricks/databricks-sdk-go/service/workspace" ) // Interface representing a file to be materialized from a template into a project @@ -68,16 +75,20 @@ func (f *copyFile) PersistToDisk() error { return err } defer srcFile.Close() - dstFile, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, f.perm) + + // we read the full file into memory because we need to inspect the content + // in order to determine if it is a notebook + // Once we stop using the workspace API, we can remove this and write in a streaming fashion + content, err := io.ReadAll(srcFile) if err != nil { return err } - defer dstFile.Close() - _, err = io.Copy(dstFile, srcFile) - return err + return writeFile(f.ctx, path, content, f.perm) } type inMemoryFile struct { + ctx context.Context + dstPath *destinationPath content []byte @@ -97,5 +108,37 @@ func (f *inMemoryFile) PersistToDisk() error { if err != nil { return err } - return os.WriteFile(path, f.content, f.perm) + + return writeFile(f.ctx, path, f.content, f.perm) +} + +func shouldUseImportNotebook(ctx context.Context, path string, content []byte) bool { + if strings.HasPrefix(path, "/Workspace/") && runtime.RunsOnDatabricks(ctx) { + isNotebook, _, err := notebook.DetectWithContent(path, content) + if err != nil { + log.Debugf(ctx, "Error detecting notebook: %v", err) + } + return isNotebook && err == nil + } + + return false +} + +func writeFile(ctx context.Context, path string, content []byte, perm fs.FileMode) error { + if shouldUseImportNotebook(ctx, path, content) { + return importNotebook(ctx, path, content) + } else { + return os.WriteFile(path, content, perm) + } +} + +func importNotebook(ctx context.Context, path string, content []byte) error { + w := root.WorkspaceClient(ctx) + + return w.Workspace.Import(ctx, workspace.Import{ + Format: "AUTO", + Overwrite: false, + Path: path, + Content: base64.StdEncoding.EncodeToString(content), + }) } diff --git a/libs/template/file_test.go b/libs/template/file_test.go index 85938895..e1b37e55 100644 --- a/libs/template/file_test.go +++ b/libs/template/file_test.go @@ -8,6 +8,11 @@ import ( "runtime" "testing" + "github.com/databricks/databricks-sdk-go/experimental/mocks" + "github.com/databricks/databricks-sdk-go/service/workspace" + "github.com/stretchr/testify/mock" + + "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/libs/filer" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -17,6 +22,7 @@ func testInMemoryFile(t *testing.T, perm fs.FileMode) { tmpDir := t.TempDir() f := &inMemoryFile{ + ctx: context.Background(), dstPath: &destinationPath{ root: tmpDir, relPath: "a/b/c", @@ -109,3 +115,37 @@ func TestTemplateCopyFilePersistToDiskForWindows(t *testing.T) { // fs.FileMode values we can use for different operating systems. testCopyFile(t, 0666) } + +func TestShouldUseImportNotebook(t *testing.T) { + ctx := context.Background() + data := []byte("# Databricks notebook source\n print('hello')") + + assert.False(t, shouldUseImportNotebook(ctx, "./foo/bar", data)) + assert.False(t, shouldUseImportNotebook(ctx, "./foo/bar.ipynb", data)) + assert.False(t, shouldUseImportNotebook(ctx, "/Workspace/foo/bar", data)) + assert.False(t, shouldUseImportNotebook(ctx, "/Workspace/foo/bar.ipynb", data)) + + t.Setenv("DATABRICKS_RUNTIME_VERSION", "14.3") + assert.False(t, shouldUseImportNotebook(ctx, "./foo/bar", data)) + assert.False(t, shouldUseImportNotebook(ctx, "./foo/bar.ipynb", data)) + assert.False(t, shouldUseImportNotebook(ctx, "/Workspace/foo/bar", data)) + assert.True(t, shouldUseImportNotebook(ctx, "/Workspace/foo/bar.py", data)) +} + +func TestImportNotebook(t *testing.T) { + ctx := context.Background() + + m := mocks.NewMockWorkspaceClient(t) + ctx = root.SetWorkspaceClient(ctx, m.WorkspaceClient) + + workspaceApi := m.GetMockWorkspaceAPI() + workspaceApi.EXPECT().Import(mock.Anything, workspace.Import{ + Content: "cXdlcnR5", // base64 of "qwerty" + Format: "AUTO", + Overwrite: false, + Path: "/Workspace/foo/bar.ipynb", + }).Return(nil) + + err := importNotebook(ctx, "/Workspace/foo/bar.ipynb", []byte("qwerty")) + assert.NoError(t, err) +} diff --git a/libs/template/renderer.go b/libs/template/renderer.go index 827f3013..c464df3f 100644 --- a/libs/template/renderer.go +++ b/libs/template/renderer.go @@ -153,12 +153,18 @@ func (r *renderer) computeFile(relPathTemplate string) (file, error) { return nil, err } + // we need the absolute path in case we need to write notebooks using the REST API + rootPath, err := filepath.Abs(r.instanceRoot) + if err != nil { + return nil, err + } + // If file name does not specify the `.tmpl` extension, then it is copied // over as is, without treating it as a template if !strings.HasSuffix(relPathTemplate, templateExtension) { return ©File{ dstPath: &destinationPath{ - root: r.instanceRoot, + root: rootPath, relPath: relPath, }, perm: perm, @@ -194,8 +200,9 @@ func (r *renderer) computeFile(relPathTemplate string) (file, error) { } return &inMemoryFile{ + ctx: r.ctx, dstPath: &destinationPath{ - root: r.instanceRoot, + root: rootPath, relPath: relPath, }, perm: perm, @@ -314,7 +321,7 @@ func (r *renderer) persistToDisk() error { if err == nil { return fmt.Errorf("failed to initialize template, one or more files already exist: %s", path) } - if err != nil && !errors.Is(err, fs.ErrNotExist) { + if !errors.Is(err, fs.ErrNotExist) { return fmt.Errorf("error while verifying file %s does not already exist: %w", path, err) } } diff --git a/libs/template/renderer_test.go b/libs/template/renderer_test.go index 92133c5f..25850126 100644 --- a/libs/template/renderer_test.go +++ b/libs/template/renderer_test.go @@ -329,6 +329,7 @@ func TestRendererPersistToDisk(t *testing.T) { skipPatterns: []string{"a/b/c", "mn*"}, files: []file{ &inMemoryFile{ + ctx: ctx, dstPath: &destinationPath{ root: tmpDir, relPath: "a/b/c", @@ -337,6 +338,7 @@ func TestRendererPersistToDisk(t *testing.T) { content: nil, }, &inMemoryFile{ + ctx: ctx, dstPath: &destinationPath{ root: tmpDir, relPath: "mno", @@ -345,6 +347,7 @@ func TestRendererPersistToDisk(t *testing.T) { content: nil, }, &inMemoryFile{ + ctx: ctx, dstPath: &destinationPath{ root: tmpDir, relPath: "a/b/d", @@ -353,6 +356,7 @@ func TestRendererPersistToDisk(t *testing.T) { content: []byte("123"), }, &inMemoryFile{ + ctx: ctx, dstPath: &destinationPath{ root: tmpDir, relPath: "mmnn",