From d38649088cc31e0b12899cfec9013c13e1ad8df0 Mon Sep 17 00:00:00 2001 From: shreyas-goenka <88374338+shreyas-goenka@users.noreply.github.com> Date: Mon, 12 Jun 2023 21:03:46 +0200 Subject: [PATCH] Add workspace import-dir command (#456) ## Tests Testing using integration tests and manually --- cmd/workspace/workspace/events.go | 31 +++- cmd/workspace/workspace/import_dir.go | 138 ++++++++++++++++++ internal/filer_test.go | 137 +++++++++++++++++ internal/testdata/import_dir/a/b/c/file-b | 1 + internal/testdata/import_dir/file-a | 1 + .../testdata/import_dir/jupyterNotebook.ipynb | 21 +++ internal/testdata/import_dir/pyNotebook.py | 2 + internal/testdata/import_dir/rNotebook.r | 2 + .../testdata/import_dir/scalaNotebook.scala | 2 + internal/testdata/import_dir/sqlNotebook.sql | 2 + internal/workspace_test.go | 98 +++++++++++-- libs/filer/workspace_files_client.go | 16 +- 12 files changed, 439 insertions(+), 12 deletions(-) create mode 100644 cmd/workspace/workspace/import_dir.go create mode 100644 internal/testdata/import_dir/a/b/c/file-b create mode 100644 internal/testdata/import_dir/file-a create mode 100644 internal/testdata/import_dir/jupyterNotebook.ipynb create mode 100644 internal/testdata/import_dir/pyNotebook.py create mode 100644 internal/testdata/import_dir/rNotebook.r create mode 100644 internal/testdata/import_dir/scalaNotebook.scala create mode 100644 internal/testdata/import_dir/sqlNotebook.sql diff --git a/cmd/workspace/workspace/events.go b/cmd/workspace/workspace/events.go index c4eb0f74..3a51bc44 100644 --- a/cmd/workspace/workspace/events.go +++ b/cmd/workspace/workspace/events.go @@ -9,10 +9,15 @@ type fileIOEvent struct { type EventType string const ( - EventTypeFileExported = EventType("FILE_EXPORTED") + EventTypeFileExported = EventType("FILE_EXPORTED") + EventTypeFileSkipped = EventType("FILE_SKIPPED") + EventTypeFileImported = EventType("FILE_IMPORTED") + EventTypeExportStarted = EventType("EXPORT_STARTED") EventTypeExportCompleted = 
EventType("EXPORT_COMPLETED") - EventTypeFileSkipped = EventType("FILE_SKIPPED") + + EventTypeImportStarted = EventType("IMPORT_STARTED") + EventTypeImportCompleted = EventType("IMPORT_COMPLETED") ) func newFileExportedEvent(sourcePath, targetPath string) fileIOEvent { @@ -44,3 +49,25 @@ func newExportStartedEvent(sourcePath string) fileIOEvent { Type: EventTypeExportStarted, } } + +func newImportStartedEvent(sourcePath string) fileIOEvent { + return fileIOEvent{ + SourcePath: sourcePath, + Type: EventTypeImportStarted, + } +} + +func newImportCompletedEvent(targetPath string) fileIOEvent { + return fileIOEvent{ + TargetPath: targetPath, + Type: EventTypeImportCompleted, + } +} + +func newFileImportedEvent(sourcePath, targetPath string) fileIOEvent { + return fileIOEvent{ + TargetPath: targetPath, + SourcePath: sourcePath, + Type: EventTypeFileImported, + } +} diff --git a/cmd/workspace/workspace/import_dir.go b/cmd/workspace/workspace/import_dir.go new file mode 100644 index 00000000..af9c38ca --- /dev/null +++ b/cmd/workspace/workspace/import_dir.go @@ -0,0 +1,138 @@ +package workspace + +import ( + "context" + "errors" + "io/fs" + "os" + "path" + "path/filepath" + "strings" + + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/filer" + "github.com/databricks/cli/libs/notebook" + "github.com/spf13/cobra" +) + +// The callback function imports the file specified at sourcePath. This function is +// meant to be used in conjunction with fs.WalkDir +// +// We deal with 3 different names for files. The need for this +// arises due to workspace API behaviour and limitations +// +// 1. Local name: The name for the file in the local file system +// 2. Remote name: The name of the file as materialized in the workspace +// 3. API payload name: The name to be used for API calls +// +// Example, consider the notebook "foo\\myNotebook.py" on a windows file system. +// The process to upload it would look like +// 1. 
Read the notebook, referring to it using its local name "foo\\myNotebook.py" +// 2. API call to import the notebook to the workspace, using its API payload name "foo/myNotebook.py" +// 3. The notebook is materialized in the workspace using its remote name "foo/myNotebook" +func importFileCallback(ctx context.Context, workspaceFiler filer.Filer, sourceDir, targetDir string) func(string, fs.DirEntry, error) error { + return func(sourcePath string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + // localName is the name for the file in the local file system + localName, err := filepath.Rel(sourceDir, sourcePath) + if err != nil { + return err + } + + // nameForApiCall is the name for the file to be used in any API call. + // This is a file name we provide to the filer.Write and Mkdir methods + nameForApiCall := filepath.ToSlash(localName) + + // create directory and return early + if d.IsDir() { + return workspaceFiler.Mkdir(ctx, nameForApiCall) + } + + // remoteName is the name of the file as visible in the workspace.
We compute + // the remote name on the client side for logging purposes + remoteName := filepath.ToSlash(localName) + isNotebook, _, err := notebook.Detect(sourcePath) + if err != nil { + return err + } + if isNotebook { + ext := path.Ext(localName) + remoteName = strings.TrimSuffix(localName, ext) + } + + // Open the local file + f, err := os.Open(sourcePath) + if err != nil { + return err + } + defer f.Close() + + // Create file in WSFS + if importOverwrite { + err = workspaceFiler.Write(ctx, nameForApiCall, f, filer.OverwriteIfExists) + if err != nil { + return err + } + } else { + err = workspaceFiler.Write(ctx, nameForApiCall, f) + if errors.Is(err, fs.ErrExist) { + // Emit file skipped event with the appropriate template + fileSkippedEvent := newFileSkippedEvent(localName, path.Join(targetDir, remoteName)) + template := "{{.SourcePath}} -> {{.TargetPath}} (skipped; already exists)\n" + return cmdio.RenderWithTemplate(ctx, fileSkippedEvent, template) + } + if err != nil { + return err + } + } + fileImportedEvent := newFileImportedEvent(localName, path.Join(targetDir, remoteName)) + return cmdio.RenderWithTemplate(ctx, fileImportedEvent, "{{.SourcePath}} -> {{.TargetPath}}\n") + } +} + +var importDirCommand = &cobra.Command{ + Use: "import-dir SOURCE_PATH TARGET_PATH", + Short: `Import a directory from the local filesystem to a Databricks workspace.`, + Long: ` +Import a directory recursively from the local file system to a Databricks workspace. 
+Notebooks will have their extensions (one of .scala, .py, .sql, .ipynb, .r) stripped +`, + PreRunE: root.MustWorkspaceClient, + Args: cobra.ExactArgs(2), + RunE: func(cmd *cobra.Command, args []string) (err error) { + ctx := cmd.Context() + w := root.WorkspaceClient(ctx) + sourceDir := args[0] + targetDir := args[1] + + // Initialize a filer rooted at targetDir + workspaceFiler, err := filer.NewWorkspaceFilesClient(w, targetDir) + if err != nil { + return err + } + + // TODO: print progress events on stderr instead: https://github.com/databricks/cli/issues/448 + err = cmdio.RenderJson(ctx, newImportStartedEvent(sourceDir)) + if err != nil { + return err + } + + // Walk local directory tree and import files to the workspace + err = filepath.WalkDir(sourceDir, importFileCallback(ctx, workspaceFiler, sourceDir, targetDir)) + if err != nil { + return err + } + return cmdio.RenderJson(ctx, newImportCompletedEvent(targetDir)) + }, +} + +var importOverwrite bool + +func init() { + importDirCommand.Flags().BoolVar(&importOverwrite, "overwrite", false, "overwrite existing workspace files") + Cmd.AddCommand(importDirCommand) +} diff --git a/internal/filer_test.go b/internal/filer_test.go index f69ee547..bc005feb 100644 --- a/internal/filer_test.go +++ b/internal/filer_test.go @@ -8,6 +8,7 @@ import ( "io" "io/fs" "net/http" + "regexp" "strings" "testing" @@ -312,6 +313,142 @@ func TestAccFilerDbfsReadDir(t *testing.T) { runFilerReadDirTest(t, ctx, f) } +var jupyterNotebookContent1 = ` +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Jupyter Notebook Version 1\")" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 + } +` + +var jupyterNotebookContent2 = ` +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Jupyter Notebook Version 
2\")" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 + } +` + +func TestAccFilerWorkspaceNotebookConflict(t *testing.T) { + ctx, f := setupWorkspaceFilesTest(t) + var err error + + // Upload the notebooks + err = f.Write(ctx, "pyNb.py", strings.NewReader("# Databricks notebook source\nprint('first upload'))")) + require.NoError(t, err) + err = f.Write(ctx, "rNb.r", strings.NewReader("# Databricks notebook source\nprint('first upload'))")) + require.NoError(t, err) + err = f.Write(ctx, "sqlNb.sql", strings.NewReader("-- Databricks notebook source\n SELECT \"first upload\"")) + require.NoError(t, err) + err = f.Write(ctx, "scalaNb.scala", strings.NewReader("// Databricks notebook source\n println(\"first upload\"))")) + require.NoError(t, err) + err = f.Write(ctx, "jupyterNb.ipynb", strings.NewReader(jupyterNotebookContent1)) + require.NoError(t, err) + + // Assert contents after initial upload + filerTest{t, f}.assertContents(ctx, "pyNb", "# Databricks notebook source\nprint('first upload'))") + filerTest{t, f}.assertContents(ctx, "rNb", "# Databricks notebook source\nprint('first upload'))") + filerTest{t, f}.assertContents(ctx, "sqlNb", "-- Databricks notebook source\n SELECT \"first upload\"") + filerTest{t, f}.assertContents(ctx, "scalaNb", "// Databricks notebook source\n println(\"first upload\"))") + filerTest{t, f}.assertContents(ctx, "jupyterNb", "# Databricks notebook source\nprint(\"Jupyter Notebook Version 1\")") + + // Assert uploading a second time fails due to overwrite mode missing + err = f.Write(ctx, "pyNb.py", strings.NewReader("# Databricks notebook source\nprint('second upload'))")) + assert.ErrorIs(t, err, fs.ErrExist) + assert.Regexp(t, regexp.MustCompile(`file already exists: .*/pyNb$`), err.Error()) + + err = f.Write(ctx, "rNb.r", strings.NewReader("# Databricks notebook source\nprint('second upload'))")) + assert.ErrorIs(t, err, fs.ErrExist) + 
assert.Regexp(t, regexp.MustCompile(`file already exists: .*/rNb$`), err.Error()) + + err = f.Write(ctx, "sqlNb.sql", strings.NewReader("# Databricks notebook source\n SELECT \"second upload\")")) + assert.ErrorIs(t, err, fs.ErrExist) + assert.Regexp(t, regexp.MustCompile(`file already exists: .*/sqlNb$`), err.Error()) + + err = f.Write(ctx, "scalaNb.scala", strings.NewReader("# Databricks notebook source\n println(\"second upload\"))")) + assert.ErrorIs(t, err, fs.ErrExist) + assert.Regexp(t, regexp.MustCompile(`file already exists: .*/scalaNb$`), err.Error()) + + err = f.Write(ctx, "jupyterNb.ipynb", strings.NewReader(jupyterNotebookContent2)) + assert.ErrorIs(t, err, fs.ErrExist) + assert.Regexp(t, regexp.MustCompile(`file already exists: .*/jupyterNb$`), err.Error()) +} + +func TestAccFilerWorkspaceNotebookWithOverwriteFlag(t *testing.T) { + ctx, f := setupWorkspaceFilesTest(t) + var err error + + // Upload notebooks + err = f.Write(ctx, "pyNb.py", strings.NewReader("# Databricks notebook source\nprint('first upload'))")) + require.NoError(t, err) + err = f.Write(ctx, "rNb.r", strings.NewReader("# Databricks notebook source\nprint('first upload'))")) + require.NoError(t, err) + err = f.Write(ctx, "sqlNb.sql", strings.NewReader("-- Databricks notebook source\n SELECT \"first upload\"")) + require.NoError(t, err) + err = f.Write(ctx, "scalaNb.scala", strings.NewReader("// Databricks notebook source\n println(\"first upload\"))")) + require.NoError(t, err) + err = f.Write(ctx, "jupyterNb.ipynb", strings.NewReader(jupyterNotebookContent1)) + require.NoError(t, err) + + // Assert contents after initial upload + filerTest{t, f}.assertContents(ctx, "pyNb", "# Databricks notebook source\nprint('first upload'))") + filerTest{t, f}.assertContents(ctx, "rNb", "# Databricks notebook source\nprint('first upload'))") + filerTest{t, f}.assertContents(ctx, "sqlNb", "-- Databricks notebook source\n SELECT \"first upload\"") + filerTest{t, f}.assertContents(ctx, "scalaNb", "// 
Databricks notebook source\n println(\"first upload\"))") + filerTest{t, f}.assertContents(ctx, "jupyterNb", "# Databricks notebook source\nprint(\"Jupyter Notebook Version 1\")") + + // Upload notebooks a second time, overwriting the initial uploads + err = f.Write(ctx, "pyNb.py", strings.NewReader("# Databricks notebook source\nprint('second upload'))"), filer.OverwriteIfExists) + require.NoError(t, err) + err = f.Write(ctx, "rNb.r", strings.NewReader("# Databricks notebook source\nprint('second upload'))"), filer.OverwriteIfExists) + require.NoError(t, err) + err = f.Write(ctx, "sqlNb.sql", strings.NewReader("-- Databricks notebook source\n SELECT \"second upload\""), filer.OverwriteIfExists) + require.NoError(t, err) + err = f.Write(ctx, "scalaNb.scala", strings.NewReader("// Databricks notebook source\n println(\"second upload\"))"), filer.OverwriteIfExists) + require.NoError(t, err) + err = f.Write(ctx, "jupyterNb.ipynb", strings.NewReader(jupyterNotebookContent2), filer.OverwriteIfExists) + require.NoError(t, err) + + // Assert contents have been overwritten + filerTest{t, f}.assertContents(ctx, "pyNb", "# Databricks notebook source\nprint('second upload'))") + filerTest{t, f}.assertContents(ctx, "rNb", "# Databricks notebook source\nprint('second upload'))") + filerTest{t, f}.assertContents(ctx, "sqlNb", "-- Databricks notebook source\n SELECT \"second upload\"") + filerTest{t, f}.assertContents(ctx, "scalaNb", "// Databricks notebook source\n println(\"second upload\"))") + filerTest{t, f}.assertContents(ctx, "jupyterNb", "# Databricks notebook source\nprint(\"Jupyter Notebook Version 2\")") +} + func setupFilerLocalTest(t *testing.T) (context.Context, filer.Filer) { ctx := context.Background() f, err := filer.NewLocalClient(t.TempDir()) diff --git a/internal/testdata/import_dir/a/b/c/file-b b/internal/testdata/import_dir/a/b/c/file-b new file mode 100644 index 00000000..976395cf --- /dev/null +++ b/internal/testdata/import_dir/a/b/c/file-b @@ -0,0 +1 @@
+file-in-dir diff --git a/internal/testdata/import_dir/file-a b/internal/testdata/import_dir/file-a new file mode 100644 index 00000000..4b5fa637 --- /dev/null +++ b/internal/testdata/import_dir/file-a @@ -0,0 +1 @@ +hello, world diff --git a/internal/testdata/import_dir/jupyterNotebook.ipynb b/internal/testdata/import_dir/jupyterNotebook.ipynb new file mode 100644 index 00000000..511115a7 --- /dev/null +++ b/internal/testdata/import_dir/jupyterNotebook.ipynb @@ -0,0 +1,21 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"jupyter\")" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/internal/testdata/import_dir/pyNotebook.py b/internal/testdata/import_dir/pyNotebook.py new file mode 100644 index 00000000..a5122a3e --- /dev/null +++ b/internal/testdata/import_dir/pyNotebook.py @@ -0,0 +1,2 @@ +# Databricks notebook source +print("python") diff --git a/internal/testdata/import_dir/rNotebook.r b/internal/testdata/import_dir/rNotebook.r new file mode 100644 index 00000000..2e581c33 --- /dev/null +++ b/internal/testdata/import_dir/rNotebook.r @@ -0,0 +1,2 @@ +# Databricks notebook source +print("r") diff --git a/internal/testdata/import_dir/scalaNotebook.scala b/internal/testdata/import_dir/scalaNotebook.scala new file mode 100644 index 00000000..7e4284ac --- /dev/null +++ b/internal/testdata/import_dir/scalaNotebook.scala @@ -0,0 +1,2 @@ +// Databricks notebook source +println("scala") diff --git a/internal/testdata/import_dir/sqlNotebook.sql b/internal/testdata/import_dir/sqlNotebook.sql new file mode 100644 index 00000000..9b5a6d31 --- /dev/null +++ b/internal/testdata/import_dir/sqlNotebook.sql @@ -0,0 +1,2 @@ +-- Databricks notebook source +SELECT "sql" diff --git a/internal/workspace_test.go b/internal/workspace_test.go index bfa323c5..b65df8ce 100644 --- 
a/internal/workspace_test.go +++ b/internal/workspace_test.go @@ -3,6 +3,7 @@ package internal import ( "context" "errors" + "io/ioutil" "net/http" "os" "path/filepath" @@ -58,13 +59,21 @@ func setupWorkspaceImportExportTest(t *testing.T) (context.Context, filer.Filer, } // TODO: add tests for the progress event output logs: https://github.com/databricks/cli/issues/447 -func assertFileContents(t *testing.T, path string, content string) { +func assertLocalFileContents(t *testing.T, path string, content string) { require.FileExists(t, path) b, err := os.ReadFile(path) require.NoError(t, err) assert.Contains(t, string(b), content) } +func assertFilerFileContents(t *testing.T, ctx context.Context, f filer.Filer, path string, content string) { + r, err := f.Read(ctx, path) + require.NoError(t, err) + b, err := ioutil.ReadAll(r) + require.NoError(t, err) + assert.Contains(t, string(b), content) +} + func TestAccExportDir(t *testing.T) { ctx, f, sourceDir := setupWorkspaceImportExportTest(t) targetDir := t.TempDir() @@ -89,12 +98,12 @@ func TestAccExportDir(t *testing.T) { RequireSuccessfulRun(t, "workspace", "export-dir", sourceDir, targetDir) // Assert files were exported - assertFileContents(t, filepath.Join(targetDir, "file-a"), "abc") - assertFileContents(t, filepath.Join(targetDir, "pyNotebook.py"), "# Databricks notebook source") - assertFileContents(t, filepath.Join(targetDir, "sqlNotebook.sql"), "-- Databricks notebook source") - assertFileContents(t, filepath.Join(targetDir, "rNotebook.r"), "# Databricks notebook source") - assertFileContents(t, filepath.Join(targetDir, "scalaNotebook.scala"), "// Databricks notebook source") - assertFileContents(t, filepath.Join(targetDir, "a/b/c/file-b"), "def") + assertLocalFileContents(t, filepath.Join(targetDir, "file-a"), "abc") + assertLocalFileContents(t, filepath.Join(targetDir, "pyNotebook.py"), "# Databricks notebook source") + assertLocalFileContents(t, filepath.Join(targetDir, "sqlNotebook.sql"), "-- Databricks 
notebook source") + assertLocalFileContents(t, filepath.Join(targetDir, "rNotebook.r"), "# Databricks notebook source") + assertLocalFileContents(t, filepath.Join(targetDir, "scalaNotebook.scala"), "// Databricks notebook source") + assertLocalFileContents(t, filepath.Join(targetDir, "a/b/c/file-b"), "def") } func TestAccExportDirDoesNotOverwrite(t *testing.T) { @@ -115,7 +124,7 @@ func TestAccExportDirDoesNotOverwrite(t *testing.T) { RequireSuccessfulRun(t, "workspace", "export-dir", sourceDir, targetDir) // Assert file is not overwritten - assertFileContents(t, filepath.Join(targetDir, "file-a"), "local content") + assertLocalFileContents(t, filepath.Join(targetDir, "file-a"), "local content") } func TestAccExportDirWithOverwriteFlag(t *testing.T) { @@ -136,5 +145,76 @@ func TestAccExportDirWithOverwriteFlag(t *testing.T) { RequireSuccessfulRun(t, "workspace", "export-dir", sourceDir, targetDir, "--overwrite") // Assert file has been overwritten - assertFileContents(t, filepath.Join(targetDir, "file-a"), "content from workspace") + assertLocalFileContents(t, filepath.Join(targetDir, "file-a"), "content from workspace") +} + +// TODO: Add assertions on progress logs for workspace import-dir command. 
https://github.com/databricks/cli/issues/455 +func TestAccImportDir(t *testing.T) { + ctx, workspaceFiler, targetDir := setupWorkspaceImportExportTest(t) + RequireSuccessfulRun(t, "workspace", "import-dir", "./testdata/import_dir", targetDir, "--log-level=debug") + + // Assert files are imported + assertFilerFileContents(t, ctx, workspaceFiler, "file-a", "hello, world") + assertFilerFileContents(t, ctx, workspaceFiler, "a/b/c/file-b", "file-in-dir") + assertFilerFileContents(t, ctx, workspaceFiler, "pyNotebook", "# Databricks notebook source\nprint(\"python\")") + assertFilerFileContents(t, ctx, workspaceFiler, "sqlNotebook", "-- Databricks notebook source\nSELECT \"sql\"") + assertFilerFileContents(t, ctx, workspaceFiler, "rNotebook", "# Databricks notebook source\nprint(\"r\")") + assertFilerFileContents(t, ctx, workspaceFiler, "scalaNotebook", "// Databricks notebook source\nprintln(\"scala\")") + assertFilerFileContents(t, ctx, workspaceFiler, "jupyterNotebook", "# Databricks notebook source\nprint(\"jupyter\")") +} + +func TestAccImportDirDoesNotOverwrite(t *testing.T) { + ctx, workspaceFiler, targetDir := setupWorkspaceImportExportTest(t) + var err error + + // create preexisting files in the workspace + err = workspaceFiler.Write(ctx, "file-a", strings.NewReader("old file")) + require.NoError(t, err) + err = workspaceFiler.Write(ctx, "pyNotebook.py", strings.NewReader("# Databricks notebook source\nprint(\"old notebook\")")) + require.NoError(t, err) + + // Assert contents of pre existing files + assertFilerFileContents(t, ctx, workspaceFiler, "file-a", "old file") + assertFilerFileContents(t, ctx, workspaceFiler, "pyNotebook", "# Databricks notebook source\nprint(\"old notebook\")") + + RequireSuccessfulRun(t, "workspace", "import-dir", "./testdata/import_dir", targetDir) + + // Assert files are imported + assertFilerFileContents(t, ctx, workspaceFiler, "a/b/c/file-b", "file-in-dir") + assertFilerFileContents(t, ctx, workspaceFiler, "sqlNotebook", "-- 
Databricks notebook source\nSELECT \"sql\"") + assertFilerFileContents(t, ctx, workspaceFiler, "rNotebook", "# Databricks notebook source\nprint(\"r\")") + assertFilerFileContents(t, ctx, workspaceFiler, "scalaNotebook", "// Databricks notebook source\nprintln(\"scala\")") + assertFilerFileContents(t, ctx, workspaceFiler, "jupyterNotebook", "# Databricks notebook source\nprint(\"jupyter\")") + + // Assert pre existing files are not changed + assertFilerFileContents(t, ctx, workspaceFiler, "file-a", "old file") + assertFilerFileContents(t, ctx, workspaceFiler, "pyNotebook", "# Databricks notebook source\nprint(\"old notebook\")") +} + +func TestAccImportDirWithOverwriteFlag(t *testing.T) { + ctx, workspaceFiler, targetDir := setupWorkspaceImportExportTest(t) + var err error + + // create preexisting files in the workspace + err = workspaceFiler.Write(ctx, "file-a", strings.NewReader("old file")) + require.NoError(t, err) + err = workspaceFiler.Write(ctx, "pyNotebook.py", strings.NewReader("# Databricks notebook source\nprint(\"old notebook\")")) + require.NoError(t, err) + + // Assert contents of pre existing files + assertFilerFileContents(t, ctx, workspaceFiler, "file-a", "old file") + assertFilerFileContents(t, ctx, workspaceFiler, "pyNotebook", "# Databricks notebook source\nprint(\"old notebook\")") + + RequireSuccessfulRun(t, "workspace", "import-dir", "./testdata/import_dir", targetDir, "--overwrite") + + // Assert files are imported + assertFilerFileContents(t, ctx, workspaceFiler, "a/b/c/file-b", "file-in-dir") + assertFilerFileContents(t, ctx, workspaceFiler, "sqlNotebook", "-- Databricks notebook source\nSELECT \"sql\"") + assertFilerFileContents(t, ctx, workspaceFiler, "rNotebook", "# Databricks notebook source\nprint(\"r\")") + assertFilerFileContents(t, ctx, workspaceFiler, "scalaNotebook", "// Databricks notebook source\nprintln(\"scala\")") + assertFilerFileContents(t, ctx, workspaceFiler, "jupyterNotebook", "# Databricks notebook 
source\nprint(\"jupyter\")") + + // Assert pre existing files are overwritten + assertFilerFileContents(t, ctx, workspaceFiler, "file-a", "hello, world") + assertFilerFileContents(t, ctx, workspaceFiler, "pyNotebook", "# Databricks notebook source\nprint(\"python\")") } diff --git a/libs/filer/workspace_files_client.go b/libs/filer/workspace_files_client.go index 42759070..2b5e718b 100644 --- a/libs/filer/workspace_files_client.go +++ b/libs/filer/workspace_files_client.go @@ -11,6 +11,7 @@ import ( "net/http" "net/url" "path" + "regexp" "sort" "strings" "time" @@ -144,11 +145,24 @@ func (w *WorkspaceFilesClient) Write(ctx context.Context, name string, reader io return w.Write(ctx, name, bytes.NewReader(body), sliceWithout(mode, CreateParentDirectories)...) } - // This API returns 409 if the file already exists. + // This API returns 409 if the file already exists, when the object type is file if aerr.StatusCode == http.StatusConflict { return FileAlreadyExistsError{absPath} } + // This API returns 400 if the file already exists, when the object type is notebook + regex := regexp.MustCompile(`Path \((.*)\) already exists.`) + if aerr.StatusCode == http.StatusBadRequest && regex.Match([]byte(aerr.Message)) { + // Parse file path from regex capture group + matches := regex.FindStringSubmatch(aerr.Message) + if len(matches) == 2 { + return FileAlreadyExistsError{matches[1]} + } + + // Default to path specified to filer.Write if regex capture fails + return FileAlreadyExistsError{absPath} + } + return err }