From 12aae3551915942a844ac06e809d5053f3aff861 Mon Sep 17 00:00:00 2001
From: Pieter Noordhuis
Date: Wed, 14 Dec 2022 15:37:14 +0100
Subject: [PATCH] Abstract over file handling with WSFS or DBFS through filer interface (#135)

Review follow-ups folded into this patch:

* absPath rejects siblings of the root that share its name as prefix.
* WriteMode constants are typed and non-zero.
* Exported identifiers are documented.

---
 internal/filer_test.go                    | 118 +++++++++++++++++
 libs/filer/filer.go                       |  53 +++++++
 libs/filer/slice.go                       |  14 ++
 libs/filer/slice_test.go                  |  22 +++
 libs/filer/workspace_files_client.go      | 156 ++++++++++++++++++++++
 libs/filer/workspace_files_client_test.go |  78 +++++++++++
 6 files changed, 441 insertions(+)
 create mode 100644 internal/filer_test.go
 create mode 100644 libs/filer/filer.go
 create mode 100644 libs/filer/slice.go
 create mode 100644 libs/filer/slice_test.go
 create mode 100644 libs/filer/workspace_files_client.go
 create mode 100644 libs/filer/workspace_files_client_test.go

diff --git a/internal/filer_test.go b/internal/filer_test.go
new file mode 100644
index 00000000..d8dc8aba
--- /dev/null
+++ b/internal/filer_test.go
@@ -0,0 +1,118 @@
+package internal
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"strings"
+	"testing"
+
+	"github.com/databricks/bricks/libs/filer"
+	"github.com/databricks/databricks-sdk-go"
+	"github.com/databricks/databricks-sdk-go/apierr"
+	"github.com/databricks/databricks-sdk-go/service/workspace"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// filerTest bundles the test handle with the filer under test
+// so that assertion helpers have access to both.
+type filerTest struct {
+	*testing.T
+	filer.Filer
+}
+
+// assertContents reads the file at `name` and asserts that its body equals `contents`.
+func (f filerTest) assertContents(ctx context.Context, name string, contents string) {
+	reader, err := f.Read(ctx, name)
+	if !assert.NoError(f, err) {
+		return
+	}
+
+	body, err := io.ReadAll(reader)
+	if !assert.NoError(f, err) {
+		return
+	}
+
+	assert.Equal(f, contents, string(body))
+}
+
+// temporaryWorkspaceDir creates a directory below the current user's home directory
+// and registers a cleanup function that recursively removes it on test completion.
+func temporaryWorkspaceDir(t *testing.T, w *databricks.WorkspaceClient) string {
+	ctx := context.Background()
+	me, err := w.CurrentUser.Me(ctx)
+	require.NoError(t, err)
+
+	path := fmt.Sprintf("/Users/%s/%s", me.UserName, RandomName("wsfs-files-"))
+
+	// Ensure directory exists, but doesn't exist YET!
+	// Otherwise we could inadvertently remove a directory that already exists on cleanup.
+	t.Logf("mkdir %s", path)
+	err = w.Workspace.MkdirsByPath(ctx, path)
+	require.NoError(t, err)
+
+	// Remove test directory on test completion.
+	t.Cleanup(func() {
+		t.Logf("rm -rf %s", path)
+		err := w.Workspace.Delete(ctx, workspace.Delete{
+			Path:      path,
+			Recursive: true,
+		})
+		if err == nil || apierr.IsMissing(err) {
+			return
+		}
+		t.Logf("unable to remove temporary workspace path %s: %#v", path, err)
+	})
+
+	return path
+}
+
+// TestAccFilerWorkspaceFiles exercises Write, Read, and Delete of the workspace files filer.
+func TestAccFilerWorkspaceFiles(t *testing.T) {
+	t.Log(GetEnvOrSkipTest(t, "CLOUD_ENV"))
+
+	ctx := context.Background()
+	w := databricks.Must(databricks.NewWorkspaceClient())
+	tmpdir := temporaryWorkspaceDir(t, w)
+	f, err := filer.NewWorkspaceFilesClient(w, tmpdir)
+	require.NoError(t, err)
+
+	// Check if we can use this API here, skip test if we cannot.
+	_, err = f.Read(ctx, "we_use_this_call_to_test_if_this_api_is_enabled")
+	if aerr, ok := err.(apierr.APIError); ok && aerr.StatusCode == http.StatusBadRequest {
+		t.Skip(aerr.Message)
+	}
+
+	// Write should fail because the parent directory doesn't yet exist.
+	err = f.Write(ctx, "/foo/bar", strings.NewReader(`"hello world"`))
+	assert.True(t, errors.As(err, &filer.NoSuchDirectoryError{}))
+
+	// Read should fail because the file doesn't yet exist.
+	_, err = f.Read(ctx, "/foo/bar")
+	assert.True(t, apierr.IsMissing(err))
+
+	// Write with CreateParentDirectories flag should succeed.
+	err = f.Write(ctx, "/foo/bar", strings.NewReader(`"hello world"`), filer.CreateParentDirectories)
+	assert.NoError(t, err)
+	filerTest{t, f}.assertContents(ctx, "/foo/bar", `"hello world"`)
+
+	// Write should fail because there is an existing file at the specified path.
+	err = f.Write(ctx, "/foo/bar", strings.NewReader(`"hello universe"`))
+	assert.True(t, errors.As(err, &filer.FileAlreadyExistsError{}))
+
+	// Write with OverwriteIfExists should succeed.
+	err = f.Write(ctx, "/foo/bar", strings.NewReader(`"hello universe"`), filer.OverwriteIfExists)
+	assert.NoError(t, err)
+	filerTest{t, f}.assertContents(ctx, "/foo/bar", `"hello universe"`)
+
+	// Delete should fail if the file doesn't exist.
+	err = f.Delete(ctx, "/doesnt_exist")
+	assert.True(t, apierr.IsMissing(err))
+
+	// Delete should succeed for file that does exist.
+	err = f.Delete(ctx, "/foo/bar")
+	assert.NoError(t, err)
+}
diff --git a/libs/filer/filer.go b/libs/filer/filer.go
new file mode 100644
index 00000000..92de6e12
--- /dev/null
+++ b/libs/filer/filer.go
@@ -0,0 +1,53 @@
+package filer
+
+import (
+	"context"
+	"fmt"
+	"io"
+)
+
+// WriteMode specifies optional behavior for Filer.Write.
+// Multiple modes are passed as separate variadic arguments.
+type WriteMode int
+
+const (
+	// OverwriteIfExists replaces the file if one already exists at the path.
+	OverwriteIfExists WriteMode = 1 << iota
+
+	// CreateParentDirectories creates the parent directory of the path if it doesn't exist.
+	CreateParentDirectories
+)
+
+// FileAlreadyExistsError is returned when writing to a path where a file
+// already exists and OverwriteIfExists is not specified.
+type FileAlreadyExistsError struct {
+	path string
+}
+
+func (err FileAlreadyExistsError) Error() string {
+	return fmt.Sprintf("file already exists: %s", err.path)
+}
+
+// NoSuchDirectoryError is returned when writing to a path whose parent
+// directory doesn't exist and CreateParentDirectories is not specified.
+type NoSuchDirectoryError struct {
+	path string
+}
+
+func (err NoSuchDirectoryError) Error() string {
+	return fmt.Sprintf("no such directory: %s", err.path)
+}
+
+// Filer is used to access files in a workspace.
+// It has implementations for accessing files in WSFS and in DBFS.
+type Filer interface {
+	// Write file at `path`.
+	// Use the mode to further specify behavior.
+	Write(ctx context.Context, path string, reader io.Reader, mode ...WriteMode) error
+
+	// Read file at `path`.
+	Read(ctx context.Context, path string) (io.Reader, error)
+
+	// Delete file at `path`.
+	Delete(ctx context.Context, path string) error
+}
diff --git a/libs/filer/slice.go b/libs/filer/slice.go
new file mode 100644
index 00000000..c35d6e78
--- /dev/null
+++ b/libs/filer/slice.go
@@ -0,0 +1,14 @@
+package filer
+
+import "golang.org/x/exp/slices"
+
+// sliceWithout returns a copy of the specified slice without element e, if it is present.
+// Only the first occurrence of e is removed; the specified slice itself is not modified.
+func sliceWithout[S []E, E comparable](s S, e E) S {
+	clone := slices.Clone(s)
+	i := slices.Index(clone, e)
+	if i >= 0 {
+		clone = slices.Delete(clone, i, i+1)
+	}
+	return clone
+}
diff --git a/libs/filer/slice_test.go b/libs/filer/slice_test.go
new file mode 100644
index 00000000..21d78348
--- /dev/null
+++ b/libs/filer/slice_test.go
@@ -0,0 +1,22 @@
+package filer
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestSliceWithout(t *testing.T) {
+	assert.Equal(t, []int{}, sliceWithout([]int{}, 0))
+	assert.Equal(t, []int{1, 2, 3}, sliceWithout([]int{1, 2, 3}, 4))
+	assert.Equal(t, []int{2, 3}, sliceWithout([]int{1, 2, 3}, 1))
+	assert.Equal(t, []int{1, 3}, sliceWithout([]int{1, 2, 3}, 2))
+	assert.Equal(t, []int{1, 2}, sliceWithout([]int{1, 2, 3}, 3))
+	assert.Equal(t, []int{2, 1}, sliceWithout([]int{1, 2, 1}, 1))
+}
+
+func TestSliceWithoutReturnsClone(t *testing.T) {
+	var ints = []int{1, 2, 3}
+	assert.Equal(t, []int{2, 3}, sliceWithout(ints, 1))
+	assert.Equal(t, []int{1, 2, 3}, ints)
+}
diff --git a/libs/filer/workspace_files_client.go b/libs/filer/workspace_files_client.go
new file mode 100644
index 00000000..03aa1b97
--- /dev/null
+++ b/libs/filer/workspace_files_client.go
@@ -0,0 +1,156 @@
+package filer
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"path"
+	"strings"
+
+	"github.com/databricks/databricks-sdk-go"
+	"github.com/databricks/databricks-sdk-go/apierr"
+	"github.com/databricks/databricks-sdk-go/client"
+	"github.com/databricks/databricks-sdk-go/service/workspace"
+	"golang.org/x/exp/slices"
+)
+
+// WorkspaceFilesClient implements the Files-in-Workspace API.
+type WorkspaceFilesClient struct {
+	workspaceClient *databricks.WorkspaceClient
+	apiClient       *client.DatabricksClient
+
+	// File operations will be relative to this path.
+	root string
+}
+
+// NewWorkspaceFilesClient returns a Filer for workspace files below the path `root`.
+// It returns an error if an API client cannot be created from the workspace client's configuration.
+func NewWorkspaceFilesClient(w *databricks.WorkspaceClient, root string) (Filer, error) {
+	apiClient, err := client.New(w.Config)
+	if err != nil {
+		return nil, err
+	}
+
+	return &WorkspaceFilesClient{
+		workspaceClient: w,
+		apiClient:       apiClient,
+
+		root: path.Clean(root),
+	}, nil
+}
+
+// absPath resolves `name` relative to the root path.
+// It returns an error if the result is the root path itself or is located outside of it.
+func (w *WorkspaceFilesClient) absPath(name string) (string, error) {
+	absPath := path.Join(w.root, name)
+
+	// Don't allow name to resolve to the root path.
+	if absPath == w.root {
+		return "", fmt.Errorf("relative path resolves to root: %s", name)
+	}
+
+	// Don't allow escaping the specified root using relative paths.
+	// Match on the root followed by a separator; a bare prefix match would accept
+	// a sibling that shares the root's name as prefix (e.g. "/root-other" for "/root").
+	if !strings.HasPrefix(absPath, strings.TrimSuffix(w.root, "/")+"/") {
+		return "", fmt.Errorf("relative path escapes root: %s", name)
+	}
+
+	return absPath, nil
+}
+
+// Write uploads the contents of `reader` to the file at `name`, relative to the root path.
+// It returns NoSuchDirectoryError if the parent directory doesn't exist and
+// CreateParentDirectories is not specified, and FileAlreadyExistsError if the
+// file already exists and OverwriteIfExists is not specified.
+func (w *WorkspaceFilesClient) Write(ctx context.Context, name string, reader io.Reader, mode ...WriteMode) error {
+	absPath, err := w.absPath(name)
+	if err != nil {
+		return err
+	}
+
+	overwrite := slices.Contains(mode, OverwriteIfExists)
+
+	// Remove leading "/" so we can use it in the URL.
+	urlPath := fmt.Sprintf(
+		"/api/2.0/workspace-files/import-file/%s?overwrite=%t",
+		strings.TrimLeft(absPath, "/"),
+		overwrite,
+	)
+
+	// Buffer the file contents because we may need to retry below and we cannot read twice.
+	body, err := io.ReadAll(reader)
+	if err != nil {
+		return err
+	}
+
+	err = w.apiClient.Do(ctx, http.MethodPost, urlPath, body, nil)
+
+	// If we got an API error we deal with it below.
+	aerr, ok := err.(apierr.APIError)
+	if !ok {
+		return err
+	}
+
+	// This API returns a 404 if the parent directory does not exist.
+	if aerr.StatusCode == http.StatusNotFound {
+		if !slices.Contains(mode, CreateParentDirectories) {
+			return NoSuchDirectoryError{path.Dir(absPath)}
+		}
+
+		// Create parent directory.
+		err = w.workspaceClient.Workspace.MkdirsByPath(ctx, path.Dir(absPath))
+		if err != nil {
+			return fmt.Errorf("unable to mkdir to write file %s: %w", absPath, err)
+		}
+
+		// Retry without CreateParentDirectories mode flag.
+		return w.Write(ctx, name, bytes.NewReader(body), sliceWithout(mode, CreateParentDirectories)...)
+	}
+
+	// This API returns 409 if the file already exists.
+	if aerr.StatusCode == http.StatusConflict {
+		return FileAlreadyExistsError{absPath}
+	}
+
+	return err
+}
+
+// Read returns a reader for the contents of the file at `name`, relative to the root path.
+func (w *WorkspaceFilesClient) Read(ctx context.Context, name string) (io.Reader, error) {
+	absPath, err := w.absPath(name)
+	if err != nil {
+		return nil, err
+	}
+
+	// Remove leading "/" so we can use it in the URL.
+	urlPath := fmt.Sprintf(
+		"/api/2.0/workspace-files/%s",
+		strings.TrimLeft(absPath, "/"),
+	)
+
+	// Update to []byte after https://github.com/databricks/databricks-sdk-go/pull/247 is merged.
+	var res json.RawMessage
+	err = w.apiClient.Do(ctx, http.MethodGet, urlPath, nil, &res)
+	if err != nil {
+		return nil, err
+	}
+
+	return bytes.NewReader(res), nil
+}
+
+// Delete removes the file at `name`, relative to the root path, without recursing.
+func (w *WorkspaceFilesClient) Delete(ctx context.Context, name string) error {
+	absPath, err := w.absPath(name)
+	if err != nil {
+		return err
+	}
+
+	return w.workspaceClient.Workspace.Delete(ctx, workspace.Delete{
+		Path:      absPath,
+		Recursive: false,
+	})
+}
diff --git a/libs/filer/workspace_files_client_test.go b/libs/filer/workspace_files_client_test.go
new file mode 100644
index 00000000..3700882a
--- /dev/null
+++ b/libs/filer/workspace_files_client_test.go
@@ -0,0 +1,78 @@
+package filer
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestWorkspaceFilesClientPaths(t *testing.T) {
+	root := "/some/root/path"
+	f := WorkspaceFilesClient{root: root}
+
+	remotePath, err := f.absPath("a/b/c")
+	assert.NoError(t, err)
+	assert.Equal(t, root+"/a/b/c", remotePath)
+
+	remotePath, err = f.absPath("a/b/../d")
+	assert.NoError(t, err)
+	assert.Equal(t, root+"/a/d", remotePath)
+
+	remotePath, err = f.absPath("a/../c")
+	assert.NoError(t, err)
+	assert.Equal(t, root+"/c", remotePath)
+
+	remotePath, err = f.absPath("a/b/c/.")
+	assert.NoError(t, err)
+	assert.Equal(t, root+"/a/b/c", remotePath)
+
+	remotePath, err = f.absPath("a/b/c/d/./../../f/g")
+	assert.NoError(t, err)
+	assert.Equal(t, root+"/a/b/f/g", remotePath)
+
+	_, err = f.absPath("..")
+	assert.ErrorContains(t, err, `relative path escapes root: ..`)
+
+	_, err = f.absPath("a/../..")
+	assert.ErrorContains(t, err, `relative path escapes root: a/../..`)
+
+	_, err = f.absPath("./../.")
+	assert.ErrorContains(t, err, `relative path escapes root: ./../.`)
+
+	_, err = f.absPath("/./.././..")
+	assert.ErrorContains(t, err, `relative path escapes root: /./.././..`)
+
+	_, err = f.absPath("./../.")
+	assert.ErrorContains(t, err, `relative path escapes root: ./../.`)
+
+	_, err = f.absPath("./..")
+	assert.ErrorContains(t, err, `relative path escapes root: ./..`)
+
+	_, err = f.absPath("./../../..")
+	assert.ErrorContains(t, err, `relative path escapes root: ./../../..`)
+
+	_, err = f.absPath("./../a/./b../../..")
+	assert.ErrorContains(t, err, `relative path escapes root: ./../a/./b../../..`)
+
+	_, err = f.absPath("../..")
+	assert.ErrorContains(t, err, `relative path escapes root: ../..`)
+
+	// A sibling of the root that shares its name as prefix is outside of the root.
+	_, err = f.absPath("../pathological")
+	assert.ErrorContains(t, err, `relative path escapes root: ../pathological`)
+
+	_, err = f.absPath(".//a/..//./b/..")
+	assert.ErrorContains(t, err, `relative path resolves to root: .//a/..//./b/..`)
+
+	_, err = f.absPath("a/b/../..")
+	assert.ErrorContains(t, err, "relative path resolves to root: a/b/../..")
+
+	_, err = f.absPath("")
+	assert.ErrorContains(t, err, "relative path resolves to root: ")
+
+	_, err = f.absPath(".")
+	assert.ErrorContains(t, err, "relative path resolves to root: .")
+
+	_, err = f.absPath("/")
+	assert.ErrorContains(t, err, "relative path resolves to root: /")
+}