diff --git a/internal/filer_test.go b/internal/filer_test.go index bc4c94808..207ce68e8 100644 --- a/internal/filer_test.go +++ b/internal/filer_test.go @@ -736,6 +736,44 @@ func TestAccWorkspaceFilesExtensionsDirectoriesAreNotNotebooks(t *testing.T) { assert.ErrorIs(t, err, fs.ErrNotExist) } +func TestAccWorkspaceFilesExtensionsNotebooksAreNotReadAsFiles(t *testing.T) { + t.Parallel() + + ctx := context.Background() + wf, _ := setupWsfsExtensionsFiler(t) + + // Create a notebook + err := wf.Write(ctx, "foo.ipynb", strings.NewReader(jupyterNotebookContent1)) + require.NoError(t, err) + + // Reading foo should fail. Even though the WSFS name for the notebook is foo + // reading the notebook should only work with the .ipynb extension. + _, err = wf.Read(ctx, "foo") + assert.ErrorIs(t, err, fs.ErrNotExist) + + _, err = wf.Read(ctx, "foo.ipynb") + assert.NoError(t, err) +} + +func TestAccWorkspaceFilesExtensionsNotebooksAreNotStatAsFiles(t *testing.T) { + t.Parallel() + + ctx := context.Background() + wf, _ := setupWsfsExtensionsFiler(t) + + // Create a notebook + err := wf.Write(ctx, "foo.ipynb", strings.NewReader(jupyterNotebookContent1)) + require.NoError(t, err) + + // Reading foo should fail. Even though the WSFS name for the notebook is foo + // reading the notebook should only work with the .ipynb extension. + _, err = wf.Stat(ctx, "foo") + assert.ErrorIs(t, err, fs.ErrNotExist) + + _, err = wf.Read(ctx, "foo.ipynb") + assert.NoError(t, err) +} + func TestAccWorkspaceFilesExtensions_ExportFormatIsPreserved(t *testing.T) { t.Parallel() diff --git a/internal/helpers.go b/internal/helpers.go index 3bf387757..2c917127c 100644 --- a/internal/helpers.go +++ b/internal/helpers.go @@ -581,11 +581,10 @@ func setupWsfsFiler(t *testing.T) (filer.Filer, string) { } func setupWsfsExtensionsFiler(t *testing.T) (filer.Filer, string) { - t.Log(GetEnvOrSkipTest(t, "CLOUD_ENV")) + _, wt := acc.WorkspaceTest(t) - w := databricks.Must(databricks.NewWorkspaceClient()) - tmpdir := TemporaryWorkspaceDir(t, w) - f, err := filer.NewWorkspaceFilesExtensionsClient(w, tmpdir) + tmpdir := TemporaryWorkspaceDir(t, wt.W) + f, err := filer.NewWorkspaceFilesExtensionsClient(wt.W, tmpdir) require.NoError(t, err) return f, tmpdir diff --git a/libs/filer/workspace_files_extensions_client.go b/libs/filer/workspace_files_extensions_client.go index b24ecf7ee..ace079968 100644 --- a/libs/filer/workspace_files_extensions_client.go +++ b/libs/filer/workspace_files_extensions_client.go @@ -245,6 +245,17 @@ func (w *workspaceFilesExtensionsClient) Write(ctx context.Context, name string, // Try to read the file as a regular file. If the file is not found, try to read it as a notebook. func (w *workspaceFilesExtensionsClient) Read(ctx context.Context, name string) (io.ReadCloser, error) { + // Ensure that the file / notebook exists. We do this check here to avoid reading + // the content of a notebook called `foo` when the user actually wanted + // to read the content of a file called `foo`. + // + // To read the content of a notebook called `foo` in the workspace the user + // should use the name with the extension included like `foo.ipynb` or `foo.sql`. + _, err := w.Stat(ctx, name) + if err != nil { + return nil, err + } + r, err := w.wsfs.Read(ctx, name) // If the file is not found, it might be a notebook. @@ -301,6 +312,22 @@ func (w *workspaceFilesExtensionsClient) Delete(ctx context.Context, name string func (w *workspaceFilesExtensionsClient) Stat(ctx context.Context, name string) (fs.FileInfo, error) { info, err := w.wsfs.Stat(ctx, name) + // If an object is found and it is not a notebook, return the stat object. This check is done + // to avoid returning the stat for a notebook called `foo` when the user actually + // wanted to stat a file called `foo`. + // + // To stat the metadata of a notebook called `foo` in the workspace the user + // should use the name with the extension included like `foo.ipynb` or `foo.sql`. + if err == nil && info.Sys().(workspace.ObjectInfo).ObjectType != workspace.ObjectTypeNotebook { + return info, nil + } + + // If a notebook is found by the workspace files client, without having stripped + // the extension, this implies that no file with the same name exists. + if err == nil && info.Sys().(workspace.ObjectInfo).ObjectType == workspace.ObjectTypeNotebook { + return nil, FileDoesNotExistError{name} + } + // If the file is not found, it might be a notebook. if errors.As(err, &FileDoesNotExistError{}) { stat, serr := w.getNotebookStatByNameWithExt(ctx, name) @@ -316,7 +343,7 @@ func (w *workspaceFilesExtensionsClient) Stat(ctx context.Context, name string) return wsfsFileInfo{ObjectInfo: stat.ObjectInfo}, nil } - return info, err + return nil, err } // Note: The import API returns opaque internal errors for namespace clashes