mirror of https://github.com/databricks/cli.git
158 lines
4.5 KiB
Go
158 lines
4.5 KiB
Go
package filer
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/databricks/cli/internal/testutil"
|
|
"github.com/databricks/databricks-sdk-go/experimental/mocks"
|
|
"github.com/databricks/databricks-sdk-go/service/files"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
type mockDbfsApiClient struct {
|
|
t testutil.TestingT
|
|
isCalled bool
|
|
}
|
|
|
|
func (m *mockDbfsApiClient) Do(ctx context.Context, method, path string,
|
|
headers map[string]string, request, response any,
|
|
visitors ...func(*http.Request) error,
|
|
) error {
|
|
m.isCalled = true
|
|
|
|
require.Equal(m.t, "POST", method)
|
|
require.Equal(m.t, "/api/2.0/dbfs/put", path)
|
|
require.Contains(m.t, headers["Content-Type"], "multipart/form-data; boundary=")
|
|
contents, err := io.ReadAll(request.(io.Reader))
|
|
require.NoError(m.t, err)
|
|
require.Contains(m.t, string(contents), "hello world")
|
|
return nil
|
|
}
|
|
|
|
func TestDbfsClientForSmallFiles(t *testing.T) {
|
|
// write file to local disk
|
|
tmp := t.TempDir()
|
|
localPath := filepath.Join(tmp, "hello.txt")
|
|
err := os.WriteFile(localPath, []byte("hello world"), 0o644)
|
|
require.NoError(t, err)
|
|
|
|
// setup DBFS client with mocks
|
|
m := mocks.NewMockWorkspaceClient(t)
|
|
mockApiClient := &mockDbfsApiClient{t: t}
|
|
dbfsClient := DbfsClient{
|
|
apiClient: mockApiClient,
|
|
workspaceClient: m.WorkspaceClient,
|
|
root: NewWorkspaceRootPath("dbfs:/a/b/c"),
|
|
}
|
|
|
|
m.GetMockDbfsAPI().EXPECT().GetStatusByPath(mock.Anything, "dbfs:/a/b/c").Return(nil, nil)
|
|
|
|
// write file to DBFS
|
|
fd, err := os.Open(localPath)
|
|
require.NoError(t, err)
|
|
defer fd.Close()
|
|
|
|
err = dbfsClient.Write(context.Background(), "hello.txt", fd)
|
|
require.NoError(t, err)
|
|
|
|
// verify mock API client is called
|
|
require.True(t, mockApiClient.isCalled)
|
|
}
|
|
|
|
type mockDbfsHandle struct {
|
|
builder strings.Builder
|
|
}
|
|
|
|
func (h *mockDbfsHandle) Read(data []byte) (n int, err error) { return 0, nil }
|
|
func (h *mockDbfsHandle) Close() error { return nil }
|
|
func (h *mockDbfsHandle) WriteTo(w io.Writer) (n int64, err error) { return 0, nil }
|
|
|
|
func (h *mockDbfsHandle) ReadFrom(r io.Reader) (n int64, err error) {
|
|
b, err := io.ReadAll(r)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
num, err := h.builder.Write(b)
|
|
return int64(num), err
|
|
}
|
|
|
|
func (h *mockDbfsHandle) Write(data []byte) (n int, err error) {
|
|
return h.builder.Write(data)
|
|
}
|
|
|
|
func TestDbfsClientForLargerFiles(t *testing.T) {
|
|
// write file to local disk
|
|
tmp := t.TempDir()
|
|
localPath := filepath.Join(tmp, "hello.txt")
|
|
err := os.WriteFile(localPath, []byte("hello world"), 0o644)
|
|
require.NoError(t, err)
|
|
|
|
// Modify the max file size to 1 byte to simulate
|
|
// a large file that needs to be uploaded in chunks.
|
|
oldV := MaxDbfsPutFileSize
|
|
MaxDbfsPutFileSize = 1
|
|
t.Cleanup(func() {
|
|
MaxDbfsPutFileSize = oldV
|
|
})
|
|
|
|
// setup DBFS client with mocks
|
|
m := mocks.NewMockWorkspaceClient(t)
|
|
mockApiClient := &mockDbfsApiClient{t: t}
|
|
dbfsClient := DbfsClient{
|
|
apiClient: mockApiClient,
|
|
workspaceClient: m.WorkspaceClient,
|
|
root: NewWorkspaceRootPath("dbfs:/a/b/c"),
|
|
}
|
|
|
|
h := &mockDbfsHandle{}
|
|
m.GetMockDbfsAPI().EXPECT().GetStatusByPath(mock.Anything, "dbfs:/a/b/c").Return(nil, nil)
|
|
m.GetMockDbfsAPI().EXPECT().Open(mock.Anything, "dbfs:/a/b/c/hello.txt", files.FileModeWrite).Return(h, nil)
|
|
|
|
// write file to DBFS
|
|
fd, err := os.Open(localPath)
|
|
require.NoError(t, err)
|
|
defer fd.Close()
|
|
|
|
err = dbfsClient.Write(context.Background(), "hello.txt", fd)
|
|
require.NoError(t, err)
|
|
|
|
// verify mock API client is NOT called
|
|
require.False(t, mockApiClient.isCalled)
|
|
|
|
// verify the file content was written to the mock handle
|
|
assert.Equal(t, "hello world", h.builder.String())
|
|
}
|
|
|
|
func TestDbfsClientForNonLocalFiles(t *testing.T) {
|
|
// setup DBFS client with mocks
|
|
m := mocks.NewMockWorkspaceClient(t)
|
|
mockApiClient := &mockDbfsApiClient{t: t}
|
|
dbfsClient := DbfsClient{
|
|
apiClient: mockApiClient,
|
|
workspaceClient: m.WorkspaceClient,
|
|
root: NewWorkspaceRootPath("dbfs:/a/b/c"),
|
|
}
|
|
|
|
h := &mockDbfsHandle{}
|
|
m.GetMockDbfsAPI().EXPECT().GetStatusByPath(mock.Anything, "dbfs:/a/b/c").Return(nil, nil)
|
|
m.GetMockDbfsAPI().EXPECT().Open(mock.Anything, "dbfs:/a/b/c/hello.txt", files.FileModeWrite).Return(h, nil)
|
|
|
|
// write file to DBFS
|
|
err := dbfsClient.Write(context.Background(), "hello.txt", strings.NewReader("hello world"))
|
|
require.NoError(t, err)
|
|
|
|
// verify mock API client is NOT called
|
|
require.False(t, mockApiClient.isCalled)
|
|
|
|
// verify the file content was written to the mock handle
|
|
assert.Equal(t, "hello world", h.builder.String())
|
|
}
|