diff --git a/integration/libs/filer/filer_test.go b/integration/libs/filer/filer_test.go
index bc1713b30..21f3c9d8b 100644
--- a/integration/libs/filer/filer_test.go
+++ b/integration/libs/filer/filer_test.go
@@ -6,7 +6,9 @@ import (
 	"encoding/json"
 	"io"
 	"io/fs"
+	"os"
 	"path"
+	"path/filepath"
 	"strings"
 	"testing"
 
@@ -887,3 +889,80 @@ func TestWorkspaceFilesExtensions_ExportFormatIsPreserved(t *testing.T) {
 		})
 	}
 }
+
+func TestDbfsFilerForStreamingUploads(t *testing.T) {
+	ctx := context.Background()
+	f, _ := setupDbfsFiler(t)
+
+	// Set MaxDbfsPutFileSize to 1 to force streaming uploads
+	prevV := filer.MaxDbfsPutFileSize
+	filer.MaxDbfsPutFileSize = 1
+	t.Cleanup(func() {
+		filer.MaxDbfsPutFileSize = prevV
+	})
+
+	// Write a file to local disk.
+	tmpDir := t.TempDir()
+	testutil.WriteFile(t, filepath.Join(tmpDir, "foo.txt"), "foobar")
+
+	fd, err := os.Open(filepath.Join(tmpDir, "foo.txt"))
+	require.NoError(t, err)
+	defer fd.Close()
+
+	// Write a file with streaming upload
+	err = f.Write(ctx, "foo.txt", fd)
+	require.NoError(t, err)
+
+	// Assert contents
+	filerTest{t, f}.assertContents(ctx, "foo.txt", "foobar")
+
+	// Overwrite the file with streaming upload, and fail
+	err = f.Write(ctx, "foo.txt", strings.NewReader("barfoo"))
+	require.ErrorIs(t, err, fs.ErrExist)
+
+	// Overwrite the file with streaming upload, and succeed
+	err = f.Write(ctx, "foo.txt", strings.NewReader("barfoo"), filer.OverwriteIfExists)
+	require.NoError(t, err)
+
+	// Assert contents
+	filerTest{t, f}.assertContents(ctx, "foo.txt", "barfoo")
+}
+
+func TestDbfsFilerForPutUploads(t *testing.T) {
+	ctx := context.Background()
+	f, _ := setupDbfsFiler(t)
+
+	// Write a file to local disk.
+	tmpDir := t.TempDir()
+	testutil.WriteFile(t, filepath.Join(tmpDir, "foo.txt"), "foobar")
+	testutil.WriteFile(t, filepath.Join(tmpDir, "bar.txt"), "barfoo")
+	fdFoo, err := os.Open(filepath.Join(tmpDir, "foo.txt"))
+	require.NoError(t, err)
+	defer fdFoo.Close()
+
+	fdBar, err := os.Open(filepath.Join(tmpDir, "bar.txt"))
+	require.NoError(t, err)
+	defer fdBar.Close()
+
+	// Write a file with PUT upload
+	err = f.Write(ctx, "foo.txt", fdFoo)
+	require.NoError(t, err)
+
+	// Assert contents
+	filerTest{t, f}.assertContents(ctx, "foo.txt", "foobar")
+
+	// Try to overwrite the file, and fail.
+	err = f.Write(ctx, "foo.txt", fdBar)
+	require.ErrorIs(t, err, fs.ErrExist)
+
+	// Reset the file descriptor.
+	_, err = fdBar.Seek(0, io.SeekStart)
+	require.NoError(t, err)
+
+	// Overwrite the file with OverwriteIfExists flag
+	err = f.Write(ctx, "foo.txt", fdBar, filer.OverwriteIfExists)
+	require.NoError(t, err)
+
+	// Assert contents
+	filerTest{t, f}.assertContents(ctx, "foo.txt", "barfoo")
+}
diff --git a/libs/filer/dbfs_client.go b/libs/filer/dbfs_client.go
index 38e8f9f3f..6f3b2fec0 100644
--- a/libs/filer/dbfs_client.go
+++ b/libs/filer/dbfs_client.go
@@ -1,11 +1,15 @@
 package filer
 
 import (
+	"bytes"
 	"context"
 	"errors"
+	"fmt"
 	"io"
 	"io/fs"
+	"mime/multipart"
 	"net/http"
+	"os"
 	"path"
 	"slices"
 	"sort"
@@ -14,6 +18,7 @@ import (
 
 	"github.com/databricks/databricks-sdk-go"
 	"github.com/databricks/databricks-sdk-go/apierr"
+	"github.com/databricks/databricks-sdk-go/client"
 	"github.com/databricks/databricks-sdk-go/service/files"
 )
 
@@ -63,33 +68,142 @@ func (info dbfsFileInfo) Sys() any {
 	return info.fi
 }
 
+// Interface to allow mocking of the Databricks API client.
+type databricksClient interface {
+	Do(ctx context.Context, method, path string, headers map[string]string,
+		requestBody, responseBody any, visitors ...func(*http.Request) error) error
+}
+
 // DbfsClient implements the [Filer] interface for the DBFS backend.
 type DbfsClient struct {
 	workspaceClient *databricks.WorkspaceClient
 
+	apiClient databricksClient
+
 	// File operations will be relative to this path.
 	root WorkspaceRootPath
 }
 
 func NewDbfsClient(w *databricks.WorkspaceClient, root string) (Filer, error) {
+	apiClient, err := client.New(w.Config)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create API client: %w", err)
+	}
+
 	return &DbfsClient{
 		workspaceClient: w,
+		apiClient:       apiClient,
 		root:            NewWorkspaceRootPath(root),
 	}, nil
 }
 
+// The PUT API for DBFS requires the Content-Length header to be set up front on the HTTP
+// request, so we size the multipart envelope with an empty file part and add the file's size.
+func contentLength(path, overwriteField string, file *os.File) (int64, error) {
+	buf := &bytes.Buffer{}
+	writer := multipart.NewWriter(buf)
+	err := writer.WriteField("path", path)
+	if err != nil {
+		return 0, fmt.Errorf("failed to write path field in multipart form: %w", err)
+	}
+	err = writer.WriteField("overwrite", overwriteField)
+	if err != nil {
+		return 0, fmt.Errorf("failed to write overwrite field in multipart form: %w", err)
+	}
+	_, err = writer.CreateFormFile("contents", "")
+	if err != nil {
+		return 0, fmt.Errorf("failed to write contents field in multipart form: %w", err)
+	}
+	err = writer.Close()
+	if err != nil {
+		return 0, fmt.Errorf("failed to close multipart form writer: %w", err)
+	}
+
+	stat, err := file.Stat()
+	if err != nil {
+		return 0, fmt.Errorf("failed to stat file %s: %w", path, err)
+	}
+
+	return int64(buf.Len()) + stat.Size(), nil
+}
+
+func contentLengthVisitor(path, overwriteField string, file *os.File) func(*http.Request) error {
+	return func(r *http.Request) error {
+		cl, err := contentLength(path, overwriteField, file)
+		if err != nil {
+			return fmt.Errorf("failed to calculate content length: %w", err)
+		}
+		r.ContentLength = cl
+		return nil
+	}
+}
+
+func (w *DbfsClient) putFile(ctx context.Context, path string, overwrite bool, file *os.File) error {
+	overwriteField := "False"
+	if overwrite {
+		overwriteField = "True"
+	}
+
+	pr, pw := io.Pipe()
+	writer := multipart.NewWriter(pw)
+	go func() {
+		defer pw.Close()
+
+		err := writer.WriteField("path", path)
+		if err != nil {
+			pw.CloseWithError(fmt.Errorf("failed to write path field in multipart form: %w", err))
+			return
+		}
+		err = writer.WriteField("overwrite", overwriteField)
+		if err != nil {
+			pw.CloseWithError(fmt.Errorf("failed to write overwrite field in multipart form: %w", err))
+			return
+		}
+		contents, err := writer.CreateFormFile("contents", "")
+		if err != nil {
+			pw.CloseWithError(fmt.Errorf("failed to write contents field in multipart form: %w", err))
+			return
+		}
+		_, err = io.Copy(contents, file)
+		if err != nil {
+			pw.CloseWithError(fmt.Errorf("error while streaming file to dbfs: %w", err))
+			return
+		}
+		err = writer.Close()
+		if err != nil {
+			pw.CloseWithError(fmt.Errorf("failed to close multipart form writer: %w", err))
+			return
+		}
+	}()
+
+	// Request bodies with Content-Type multipart/form-data are not supported by the
+	// Go SDK's DBFS bindings, so we use the API client's Do method directly.
+	err := w.apiClient.Do(ctx,
+		http.MethodPost,
+		"/api/2.0/dbfs/put",
+		map[string]string{"Content-Type": writer.FormDataContentType()},
+		pr,
+		nil,
+		contentLengthVisitor(path, overwriteField, file))
+	var aerr *apierr.APIError
+	if errors.As(err, &aerr) && aerr.ErrorCode == "RESOURCE_ALREADY_EXISTS" {
+		return FileAlreadyExistsError{path}
+	}
+	return err
+}
+
+// MaxDbfsPutFileSize is the maximum size in bytes of a file that can be uploaded
+// using the /dbfs/put API. If the file is larger than this limit, the streaming
+// API (/dbfs/create and /dbfs/add-block) is used instead.
+var MaxDbfsPutFileSize int64 = 2 * 1024 * 1024 * 1024
+
 func (w *DbfsClient) Write(ctx context.Context, name string, reader io.Reader, mode ...WriteMode) error {
 	absPath, err := w.root.Join(name)
 	if err != nil {
 		return err
 	}
 
-	fileMode := files.FileModeWrite
-	if slices.Contains(mode, OverwriteIfExists) {
-		fileMode |= files.FileModeOverwrite
-	}
-
 	// Issue info call before write because it automatically creates parent directories.
 	//
 	// For discussion: we could decide this is actually convenient, remove the call below,
@@ -114,7 +228,36 @@ func (w *DbfsClient) Write(ctx context.Context, name string, reader io.Reader, m
 		}
 	}
 
-	handle, err := w.workspaceClient.Dbfs.Open(ctx, absPath, fileMode)
+	localFile, ok := reader.(*os.File)
+
+	// If the source is not a local file, we'll always use the streaming API endpoint.
+	if !ok {
+		return w.streamFile(ctx, absPath, slices.Contains(mode, OverwriteIfExists), reader)
+	}
+
+	stat, err := localFile.Stat()
+	if err != nil {
+		return fmt.Errorf("failed to stat file: %w", err)
+	}
+
+	// If the source is a local file but is too large, we'll use the streaming API endpoint.
+	if stat.Size() > MaxDbfsPutFileSize {
+		return w.streamFile(ctx, absPath, slices.Contains(mode, OverwriteIfExists), localFile)
+	}
+
+	// Use the /dbfs/put API when the file is on the local filesystem
+	// and is small enough. This is the most common case when users use the
+	// `databricks fs cp` command.
+	return w.putFile(ctx, absPath, slices.Contains(mode, OverwriteIfExists), localFile)
+}
+
+func (w *DbfsClient) streamFile(ctx context.Context, path string, overwrite bool, reader io.Reader) error {
+	fileMode := files.FileModeWrite
+	if overwrite {
+		fileMode |= files.FileModeOverwrite
+	}
+
+	handle, err := w.workspaceClient.Dbfs.Open(ctx, path, fileMode)
 	if err != nil {
 		var aerr *apierr.APIError
 		if !errors.As(err, &aerr) {
@@ -124,7 +267,7 @@ func (w *DbfsClient) Write(ctx context.Context, name string, reader io.Reader, m
 		// This API returns a 400 if the file already exists.
 		if aerr.StatusCode == http.StatusBadRequest {
 			if aerr.ErrorCode == "RESOURCE_ALREADY_EXISTS" {
-				return FileAlreadyExistsError{absPath}
+				return FileAlreadyExistsError{path}
 			}
 		}
 
@@ -136,7 +279,6 @@ func (w *DbfsClient) Write(ctx context.Context, name string, reader io.Reader, m
 	if err == nil {
 		err = cerr
 	}
-
 	return err
 }
 
diff --git a/libs/filer/dbfs_client_test.go b/libs/filer/dbfs_client_test.go
new file mode 100644
index 000000000..df962e5a3
--- /dev/null
+++ b/libs/filer/dbfs_client_test.go
@@ -0,0 +1,155 @@
+package filer
+
+import (
+	"context"
+	"io"
+	"net/http"
+	"os"
+	"path/filepath"
+	"strings"
+	"testing"
+
+	"github.com/databricks/cli/internal/testutil"
+	"github.com/databricks/databricks-sdk-go/experimental/mocks"
+	"github.com/databricks/databricks-sdk-go/service/files"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/mock"
+	"github.com/stretchr/testify/require"
+)
+
+type mockDbfsApiClient struct {
+	t        testutil.TestingT
+	isCalled bool
+}
+
+func (m *mockDbfsApiClient) Do(ctx context.Context, method, path string,
+	headers map[string]string, request, response any,
+	visitors ...func(*http.Request) error,
+) error {
+	m.isCalled = true
+
+	require.Equal(m.t, "POST", method)
+	require.Equal(m.t, "/api/2.0/dbfs/put", path)
+	require.Contains(m.t, headers["Content-Type"], "multipart/form-data; boundary=")
+	contents, err := io.ReadAll(request.(io.Reader))
+	require.NoError(m.t, err)
+	require.Contains(m.t, string(contents), "hello world")
+	return nil
+}
+
+func TestDbfsClientForSmallFiles(t *testing.T) {
+	// write file to local disk
+	tmp := t.TempDir()
+	localPath := filepath.Join(tmp, "hello.txt")
+	testutil.WriteFile(t, localPath, "hello world")
+
+	// setup DBFS client with mocks
+	m := mocks.NewMockWorkspaceClient(t)
+	mockApiClient := &mockDbfsApiClient{t: t}
+	dbfsClient := DbfsClient{
+		apiClient:       mockApiClient,
+		workspaceClient: m.WorkspaceClient,
+		root:            NewWorkspaceRootPath("dbfs:/a/b/c"),
+	}
+
+	m.GetMockDbfsAPI().EXPECT().GetStatusByPath(mock.Anything, "dbfs:/a/b/c").Return(nil, nil)
+
+	// write file to DBFS
+	fd, err := os.Open(localPath)
+	require.NoError(t, err)
+	defer fd.Close()
+
+	err = dbfsClient.Write(context.Background(), "hello.txt", fd)
+	require.NoError(t, err)
+
+	// verify mock API client is called
+	require.True(t, mockApiClient.isCalled)
+}
+
+type mockDbfsHandle struct {
+	builder strings.Builder
+}
+
+func (h *mockDbfsHandle) Read(data []byte) (n int, err error)      { return 0, nil }
+func (h *mockDbfsHandle) Close() error                             { return nil }
+func (h *mockDbfsHandle) WriteTo(w io.Writer) (n int64, err error) { return 0, nil }
+
+func (h *mockDbfsHandle) ReadFrom(r io.Reader) (n int64, err error) {
+	b, err := io.ReadAll(r)
+	if err != nil {
+		return 0, err
+	}
+	num, err := h.builder.Write(b)
+	return int64(num), err
+}
+
+func (h *mockDbfsHandle) Write(data []byte) (n int, err error) {
+	return h.builder.Write(data)
+}
+
+func TestDbfsClientForLargerFiles(t *testing.T) {
+	// write file to local disk
+	tmp := t.TempDir()
+	localPath := filepath.Join(tmp, "hello.txt")
+	testutil.WriteFile(t, localPath, "hello world")
+
+	// Modify the max file size to 1 byte to simulate
+	// a large file that needs to be uploaded in chunks.
+	oldV := MaxDbfsPutFileSize
+	MaxDbfsPutFileSize = 1
+	t.Cleanup(func() {
+		MaxDbfsPutFileSize = oldV
+	})
+
+	// setup DBFS client with mocks
+	m := mocks.NewMockWorkspaceClient(t)
+	mockApiClient := &mockDbfsApiClient{t: t}
+	dbfsClient := DbfsClient{
+		apiClient:       mockApiClient,
+		workspaceClient: m.WorkspaceClient,
+		root:            NewWorkspaceRootPath("dbfs:/a/b/c"),
+	}
+
+	h := &mockDbfsHandle{}
+	m.GetMockDbfsAPI().EXPECT().GetStatusByPath(mock.Anything, "dbfs:/a/b/c").Return(nil, nil)
+	m.GetMockDbfsAPI().EXPECT().Open(mock.Anything, "dbfs:/a/b/c/hello.txt", files.FileModeWrite).Return(h, nil)
+
+	// write file to DBFS
+	fd, err := os.Open(localPath)
+	require.NoError(t, err)
+	defer fd.Close()
+
+	err = dbfsClient.Write(context.Background(), "hello.txt", fd)
+	require.NoError(t, err)
+
+	// verify mock API client is NOT called
+	require.False(t, mockApiClient.isCalled)
+
+	// verify the file content was written to the mock handle
+	assert.Equal(t, "hello world", h.builder.String())
+}
+
+func TestDbfsClientForNonLocalFiles(t *testing.T) {
+	// setup DBFS client with mocks
+	m := mocks.NewMockWorkspaceClient(t)
+	mockApiClient := &mockDbfsApiClient{t: t}
+	dbfsClient := DbfsClient{
+		apiClient:       mockApiClient,
+		workspaceClient: m.WorkspaceClient,
+		root:            NewWorkspaceRootPath("dbfs:/a/b/c"),
+	}
+
+	h := &mockDbfsHandle{}
+	m.GetMockDbfsAPI().EXPECT().GetStatusByPath(mock.Anything, "dbfs:/a/b/c").Return(nil, nil)
+	m.GetMockDbfsAPI().EXPECT().Open(mock.Anything, "dbfs:/a/b/c/hello.txt", files.FileModeWrite).Return(h, nil)
+
+	// write file to DBFS
+	err := dbfsClient.Write(context.Background(), "hello.txt", strings.NewReader("hello world"))
+	require.NoError(t, err)
+
+	// verify mock API client is NOT called
+	require.False(t, mockApiClient.isCalled)
+
+	// verify the file content was written to the mock handle
+	assert.Equal(t, "hello world", h.builder.String())
+}