From 0ce50fadf02ba90abceb3c01f7f1fd5baef6b962 Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Tue, 31 Dec 2024 12:09:24 +0530 Subject: [PATCH] add unit test --- libs/filer/dbfs_client_test.go | 124 +++++++++++++++++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 libs/filer/dbfs_client_test.go diff --git a/libs/filer/dbfs_client_test.go b/libs/filer/dbfs_client_test.go new file mode 100644 index 000000000..2e8ad86d0 --- /dev/null +++ b/libs/filer/dbfs_client_test.go @@ -0,0 +1,124 @@ +package filer + +import ( + "context" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/databricks/cli/internal/testutil" + "github.com/databricks/databricks-sdk-go/experimental/mocks" + "github.com/databricks/databricks-sdk-go/service/files" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +type mockDbfsApiClient struct { + t testutil.TestingT + isCalled bool +} + +func (m *mockDbfsApiClient) Do(ctx context.Context, method, path string, + headers map[string]string, request, response any, + visitors ...func(*http.Request) error, +) error { + m.isCalled = true + + require.Equal(m.t, "POST", method) + require.Equal(m.t, "/api/2.0/dbfs/put", path) + require.Contains(m.t, headers["Content-Type"], "multipart/form-data; boundary=") + require.Contains(m.t, string(request.([]byte)), "hello world") + return nil +} + +func TestDbfsClientForSmallFiles(t *testing.T) { + // write file to local disk + tmp := t.TempDir() + localPath := filepath.Join(tmp, "hello.txt") + err := os.WriteFile(localPath, []byte("hello world"), 0o644) + require.NoError(t, err) + + // setup DBFS client with mocks + m := mocks.NewMockWorkspaceClient(t) + mockApiClient := &mockDbfsApiClient{t: t} + dbfsClient := DbfsClient{ + apiClient: mockApiClient, + workspaceClient: m.WorkspaceClient, + root: NewWorkspaceRootPath("dbfs:/a/b/c"), + } + + m.GetMockDbfsAPI().EXPECT().GetStatusByPath(mock.Anything, "dbfs:/a/b/c").Return(nil, nil) + + // write file to DBFS + fd, err := os.Open(localPath) + require.NoError(t, err) + err = dbfsClient.Write(context.Background(), "hello.txt", fd) + require.NoError(t, err) + + // verify mock API client is called + require.True(t, mockApiClient.isCalled) +} + +type mockDbfsHandle struct { + builder strings.Builder +} + +func (h *mockDbfsHandle) Write(data []byte) (n int, err error) { return 0, nil } +func (h *mockDbfsHandle) Read(data []byte) (n int, err error) { return 0, nil } +func (h *mockDbfsHandle) Close() error { return nil } +func (h *mockDbfsHandle) WriteTo(w io.Writer) (n int64, err error) { return 0, nil } + +func (h *mockDbfsHandle) ReadFrom(r io.Reader) (n int64, err error) { + b, err := io.ReadAll(r) + if err != nil { + return 0, err + } + num, err := h.builder.Write(b) + return int64(num), err +} + +func TestDbfsClientForLargerFiles(t *testing.T) { + // write file to local disk + tmp := t.TempDir() + localPath := filepath.Join(tmp, "hello.txt") + err := os.WriteFile(localPath, []byte("hello world"), 0o644) + require.NoError(t, err) + + // Modify the max file size to 1 byte to simulate + // a large file that needs to be uploaded in chunks. + oldV := MaxDbfsPutFileSize + MaxDbfsPutFileSize = 1 + t.Cleanup(func() { + MaxDbfsPutFileSize = oldV + }) + + // setup DBFS client with mocks + m := mocks.NewMockWorkspaceClient(t) + mockApiClient := &mockDbfsApiClient{t: t} + dbfsClient := DbfsClient{ + apiClient: mockApiClient, + workspaceClient: m.WorkspaceClient, + root: NewWorkspaceRootPath("dbfs:/a/b/c"), + } + + h := &mockDbfsHandle{} + + m.GetMockDbfsAPI().EXPECT().GetStatusByPath(mock.Anything, "dbfs:/a/b/c").Return(nil, nil) + m.GetMockDbfsAPI().EXPECT().Open(mock.Anything, "dbfs:/a/b/c/hello.txt", files.FileModeWrite).Return(h, nil) + + // write file to DBFS + fd, err := os.Open(localPath) + require.NoError(t, err) + err = dbfsClient.Write(context.Background(), "hello.txt", fd) + require.NoError(t, err) + + // verify mock API client is NOT called + require.False(t, mockApiClient.isCalled) + + // verify the file content was written to the mock handle + assert.Equal(t, "hello world", h.builder.String()) +}