diff --git a/cmd/fs/helpers.go b/cmd/fs/helpers.go index aecee9e0..ca85cf1b 100644 --- a/cmd/fs/helpers.go +++ b/cmd/fs/helpers.go @@ -10,47 +10,38 @@ import ( "github.com/databricks/cli/libs/filer" ) -type Scheme string - -const ( - DbfsScheme = Scheme("dbfs") - LocalScheme = Scheme("file") - NoScheme = Scheme("") -) - func filerForPath(ctx context.Context, fullPath string) (filer.Filer, string, error) { - parts := strings.SplitN(fullPath, ":/", 2) + // Split path at : to detect any file schemes + parts := strings.SplitN(fullPath, ":", 2) + + // If no scheme is specified, then local path if len(parts) < 2 { - return nil, "", fmt.Errorf(`no scheme specified for path %s. Please specify scheme "dbfs" or "file". Example: file:/foo/bar or file:/c:/foo/bar`, fullPath) + f, err := filer.NewLocalClient("") + return f, fullPath, err } - scheme := Scheme(parts[0]) + + // On windows systems, paths start with a drive letter. If the scheme + // is a single letter and the OS is windows, then we conclude the path + // is meant to be a local path. + if runtime.GOOS == "windows" && len(parts[0]) == 1 { + f, err := filer.NewLocalClient("") + return f, fullPath, err + } + + if parts[0] != "dbfs" { + return nil, "", fmt.Errorf("invalid scheme: %s", parts[0]) + } + path := parts[1] - switch scheme { - case DbfsScheme: - w := root.WorkspaceClient(ctx) - // If the specified path has the "Volumes" prefix, use the Files API. 
- if strings.HasPrefix(path, "Volumes/") { - f, err := filer.NewFilesClient(w, "/") - return f, path, err - } - f, err := filer.NewDbfsClient(w, "/") - return f, path, err + w := root.WorkspaceClient(ctx) - case LocalScheme: - if runtime.GOOS == "windows" { - parts := strings.SplitN(path, ":", 2) - if len(parts) < 2 { - return nil, "", fmt.Errorf("no volume specfied for path: %s", path) - } - volume := parts[0] + ":" - relPath := parts[1] - f, err := filer.NewLocalClient(volume) - return f, relPath, err - } - f, err := filer.NewLocalClient("/") + // If the specified path has the "Volumes" prefix, use the Files API. + if strings.HasPrefix(path, "/Volumes/") { + f, err := filer.NewFilesClient(w, "/") return f, path, err - - default: - return nil, "", fmt.Errorf(`unsupported scheme %s specified for path %s. Please specify scheme "dbfs" or "file". Example: file:/foo/bar or file:/c:/foo/bar`, scheme, fullPath) } + + // The file is a dbfs file, and uses the DBFS APIs + f, err := filer.NewDbfsClient(w, "/") + return f, path, err } diff --git a/cmd/fs/helpers_test.go b/cmd/fs/helpers_test.go index 4beda6ca..d86bd46e 100644 --- a/cmd/fs/helpers_test.go +++ b/cmd/fs/helpers_test.go @@ -5,22 +5,58 @@ import ( "runtime" "testing" + "github.com/databricks/cli/libs/filer" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func TestNotSpecifyingVolumeForWindowsPathErrors(t *testing.T) { +func TestFilerForPathForLocalPaths(t *testing.T) { + tmpDir := t.TempDir() + ctx := context.Background() + + f, path, err := filerForPath(ctx, tmpDir) + assert.NoError(t, err) + assert.Equal(t, tmpDir, path) + + info, err := f.Stat(ctx, path) + require.NoError(t, err) + assert.True(t, info.IsDir()) +} + +func TestFilerForPathForInvalidScheme(t *testing.T) { + ctx := context.Background() + + _, _, err := filerForPath(ctx, "dbf:/a") + assert.ErrorContains(t, err, "invalid scheme") + + _, _, err = filerForPath(ctx, "foo:a") + assert.ErrorContains(t, err, "invalid 
scheme") + + _, _, err = filerForPath(ctx, "file:/a") + assert.ErrorContains(t, err, "invalid scheme") +} + +func testWindowsFilerForPath(t *testing.T, ctx context.Context, fullPath string) { + f, path, err := filerForPath(ctx, fullPath) + assert.NoError(t, err) + + // Assert path remains unchanged + assert.Equal(t, path, fullPath) + + // Assert local client is created + _, ok := f.(*filer.LocalClient) + assert.True(t, ok) +} + +func TestFilerForWindowsLocalPaths(t *testing.T) { if runtime.GOOS != "windows" { - t.Skip() + t.SkipNow() } ctx := context.Background() - pathWithVolume := `file:/c:/foo/bar` - pathWOVolume := `file:/uno/dos` - - _, path, err := filerForPath(ctx, pathWithVolume) - assert.NoError(t, err) - assert.Equal(t, `/foo/bar`, path) - - _, _, err = filerForPath(ctx, pathWOVolume) - assert.Equal(t, "no volume specfied for path: uno/dos", err.Error()) + testWindowsFilerForPath(t, ctx, `c:\abc`) + testWindowsFilerForPath(t, ctx, `c:abc`) + testWindowsFilerForPath(t, ctx, `d:\abc`) + testWindowsFilerForPath(t, ctx, `d:\abc`) + testWindowsFilerForPath(t, ctx, `f:\abc\ef`) } diff --git a/internal/fs_cp_test.go b/internal/fs_cp_test.go index c9171086..766d6a59 100644 --- a/internal/fs_cp_test.go +++ b/internal/fs_cp_test.go @@ -66,7 +66,7 @@ func setupLocalFiler(t *testing.T) (filer.Filer, string) { f, err := filer.NewLocalClient(tmp) require.NoError(t, err) - return f, path.Join("file:/", filepath.ToSlash(tmp)) + return f, path.Join(filepath.ToSlash(tmp)) } func setupDbfsFiler(t *testing.T) (filer.Filer, string) { @@ -259,21 +259,14 @@ func TestAccFsCpErrorsWhenSourceIsDirWithoutRecursiveFlag(t *testing.T) { tmpDir := temporaryDbfsDir(t, w) _, _, err = RequireErrorRun(t, "fs", "cp", "dbfs:"+tmpDir, "dbfs:/tmp") - assert.Equal(t, fmt.Sprintf("source path %s is a directory. 
Please specify the --recursive flag", strings.TrimPrefix(tmpDir, "/")), err.Error()) -} - -func TestAccFsCpErrorsOnNoScheme(t *testing.T) { - t.Log(GetEnvOrSkipTest(t, "CLOUD_ENV")) - - _, _, err := RequireErrorRun(t, "fs", "cp", "/a", "/b") - assert.Equal(t, "no scheme specified for path /a. Please specify scheme \"dbfs\" or \"file\". Example: file:/foo/bar or file:/c:/foo/bar", err.Error()) + assert.Equal(t, fmt.Sprintf("source path %s is a directory. Please specify the --recursive flag", tmpDir), err.Error()) } func TestAccFsCpErrorsOnInvalidScheme(t *testing.T) { t.Log(GetEnvOrSkipTest(t, "CLOUD_ENV")) _, _, err := RequireErrorRun(t, "fs", "cp", "dbfs:/a", "https:/b") - assert.Equal(t, "unsupported scheme https specified for path https:/b. Please specify scheme \"dbfs\" or \"file\". Example: file:/foo/bar or file:/c:/foo/bar", err.Error()) + assert.Equal(t, "invalid scheme: https", err.Error()) } func TestAccFsCpSourceIsDirectoryButTargetIsFile(t *testing.T) { diff --git a/libs/filer/dbfs_client.go b/libs/filer/dbfs_client.go index 7e59638a..64eb4b77 100644 --- a/libs/filer/dbfs_client.go +++ b/libs/filer/dbfs_client.go @@ -68,14 +68,14 @@ type DbfsClient struct { workspaceClient *databricks.WorkspaceClient // File operations will be relative to this path. - root RootPath + root WorkspaceRootPath } func NewDbfsClient(w *databricks.WorkspaceClient, root string) (Filer, error) { return &DbfsClient{ workspaceClient: w, - root: NewRootPath(root), + root: NewWorkspaceRootPath(root), }, nil } diff --git a/libs/filer/files_client.go b/libs/filer/files_client.go index 6c1f5a97..ee7587dc 100644 --- a/libs/filer/files_client.go +++ b/libs/filer/files_client.go @@ -60,7 +60,7 @@ type FilesClient struct { apiClient *client.DatabricksClient // File operations will be relative to this path. 
- root RootPath + root WorkspaceRootPath } func filesNotImplementedError(fn string) error { @@ -77,7 +77,7 @@ func NewFilesClient(w *databricks.WorkspaceClient, root string) (Filer, error) { workspaceClient: w, apiClient: apiClient, - root: NewRootPath(root), + root: NewWorkspaceRootPath(root), }, nil } diff --git a/libs/filer/local_client.go b/libs/filer/local_client.go index 8df59d25..8d960c84 100644 --- a/libs/filer/local_client.go +++ b/libs/filer/local_client.go @@ -13,12 +13,12 @@ import ( // LocalClient implements the [Filer] interface for the local filesystem. type LocalClient struct { // File operations will be relative to this path. - root RootPath + root localRootPath } func NewLocalClient(root string) (Filer, error) { return &LocalClient{ - root: NewRootPath(root), + root: NewLocalRootPath(root), }, nil } diff --git a/libs/filer/local_root_path.go b/libs/filer/local_root_path.go new file mode 100644 index 00000000..15a54263 --- /dev/null +++ b/libs/filer/local_root_path.go @@ -0,0 +1,27 @@ +package filer + +import ( + "fmt" + "path/filepath" + "strings" +) + +type localRootPath struct { + rootPath string +} + +func NewLocalRootPath(root string) localRootPath { + if root == "" { + return localRootPath{""} + } + return localRootPath{filepath.Clean(root)} +} + +func (rp *localRootPath) Join(name string) (string, error) { + absPath := filepath.Join(rp.rootPath, name) + + if !strings.HasPrefix(absPath, rp.rootPath) { + return "", fmt.Errorf("relative path escapes root: %s", name) + } + return absPath, nil +} diff --git a/libs/filer/local_root_path_test.go b/libs/filer/local_root_path_test.go new file mode 100644 index 00000000..1a39c446 --- /dev/null +++ b/libs/filer/local_root_path_test.go @@ -0,0 +1,142 @@ +package filer + +import ( + "path/filepath" + "runtime" + "testing" + + "github.com/stretchr/testify/assert" +) + +func testUnixLocalRootPath(t *testing.T, uncleanRoot string) { + cleanRoot := filepath.Clean(uncleanRoot) + rp := 
NewLocalRootPath(uncleanRoot) + + remotePath, err := rp.Join("a/b/c") + assert.NoError(t, err) + assert.Equal(t, cleanRoot+"/a/b/c", remotePath) + + remotePath, err = rp.Join("a/b/../d") + assert.NoError(t, err) + assert.Equal(t, cleanRoot+"/a/d", remotePath) + + remotePath, err = rp.Join("a/../c") + assert.NoError(t, err) + assert.Equal(t, cleanRoot+"/c", remotePath) + + remotePath, err = rp.Join("a/b/c/.") + assert.NoError(t, err) + assert.Equal(t, cleanRoot+"/a/b/c", remotePath) + + remotePath, err = rp.Join("a/b/c/d/./../../f/g") + assert.NoError(t, err) + assert.Equal(t, cleanRoot+"/a/b/f/g", remotePath) + + remotePath, err = rp.Join(".//a/..//./b/..") + assert.NoError(t, err) + assert.Equal(t, cleanRoot, remotePath) + + remotePath, err = rp.Join("a/b/../..") + assert.NoError(t, err) + assert.Equal(t, cleanRoot, remotePath) + + remotePath, err = rp.Join("") + assert.NoError(t, err) + assert.Equal(t, cleanRoot, remotePath) + + remotePath, err = rp.Join(".") + assert.NoError(t, err) + assert.Equal(t, cleanRoot, remotePath) + + remotePath, err = rp.Join("/") + assert.NoError(t, err) + assert.Equal(t, cleanRoot, remotePath) + + _, err = rp.Join("..") + assert.ErrorContains(t, err, `relative path escapes root: ..`) + + _, err = rp.Join("a/../..") + assert.ErrorContains(t, err, `relative path escapes root: a/../..`) + + _, err = rp.Join("./../.") + assert.ErrorContains(t, err, `relative path escapes root: ./../.`) + + _, err = rp.Join("/./.././..") + assert.ErrorContains(t, err, `relative path escapes root: /./.././..`) + + _, err = rp.Join("./../.") + assert.ErrorContains(t, err, `relative path escapes root: ./../.`) + + _, err = rp.Join("./..") + assert.ErrorContains(t, err, `relative path escapes root: ./..`) + + _, err = rp.Join("./../../..") + assert.ErrorContains(t, err, `relative path escapes root: ./../../..`) + + _, err = rp.Join("./../a/./b../../..") + assert.ErrorContains(t, err, `relative path escapes root: ./../a/./b../../..`) + + _, err = 
rp.Join("../..") + assert.ErrorContains(t, err, `relative path escapes root: ../..`) +} + +func TestUnixLocalRootPath(t *testing.T) { + if runtime.GOOS != "darwin" && runtime.GOOS != "linux" { + t.SkipNow() + } + + testUnixLocalRootPath(t, "/some/root/path") + testUnixLocalRootPath(t, "/some/root/path/") + testUnixLocalRootPath(t, "/some/root/path/.") + testUnixLocalRootPath(t, "/some/root/../path/") +} + +func testWindowsLocalRootPath(t *testing.T, uncleanRoot string) { + cleanRoot := filepath.Clean(uncleanRoot) + rp := NewLocalRootPath(uncleanRoot) + + remotePath, err := rp.Join(`a\b\c`) + assert.NoError(t, err) + assert.Equal(t, cleanRoot+`\a\b\c`, remotePath) + + remotePath, err = rp.Join(`a\b\..\d`) + assert.NoError(t, err) + assert.Equal(t, cleanRoot+`\a\d`, remotePath) + + remotePath, err = rp.Join(`a\..\c`) + assert.NoError(t, err) + assert.Equal(t, cleanRoot+`\c`, remotePath) + + remotePath, err = rp.Join(`a\b\c\.`) + assert.NoError(t, err) + assert.Equal(t, cleanRoot+`\a\b\c`, remotePath) + + remotePath, err = rp.Join(`a\b\..\..`) + assert.NoError(t, err) + assert.Equal(t, cleanRoot, remotePath) + + remotePath, err = rp.Join("") + assert.NoError(t, err) + assert.Equal(t, cleanRoot, remotePath) + + remotePath, err = rp.Join(".") + assert.NoError(t, err) + assert.Equal(t, cleanRoot, remotePath) + + _, err = rp.Join("..") + assert.ErrorContains(t, err, `relative path escapes root`) + + _, err = rp.Join(`a\..\..`) + assert.ErrorContains(t, err, `relative path escapes root`) +} + +func TestWindowsLocalRootPath(t *testing.T) { + if runtime.GOOS != "windows" { + t.SkipNow() + } + + testWindowsLocalRootPath(t, `c:\some\root\path`) + testWindowsLocalRootPath(t, `c:\some\root\path\`) + testWindowsLocalRootPath(t, `c:\some\root\path\.`) + testWindowsLocalRootPath(t, `C:\some\root\..\path\`) +} diff --git a/libs/filer/workspace_files_client.go b/libs/filer/workspace_files_client.go index 12a536bf..db06f91c 100644 --- a/libs/filer/workspace_files_client.go +++ 
b/libs/filer/workspace_files_client.go @@ -78,7 +78,7 @@ type WorkspaceFilesClient struct { apiClient *client.DatabricksClient // File operations will be relative to this path. - root RootPath + root WorkspaceRootPath } func NewWorkspaceFilesClient(w *databricks.WorkspaceClient, root string) (Filer, error) { @@ -91,7 +91,7 @@ func NewWorkspaceFilesClient(w *databricks.WorkspaceClient, root string) (Filer, workspaceClient: w, apiClient: apiClient, - root: NewRootPath(root), + root: NewWorkspaceRootPath(root), }, nil } diff --git a/libs/filer/root_path.go b/libs/filer/workspace_root_path.go similarity index 65% rename from libs/filer/root_path.go rename to libs/filer/workspace_root_path.go index bdeff5d7..d5163924 100644 --- a/libs/filer/root_path.go +++ b/libs/filer/workspace_root_path.go @@ -6,23 +6,23 @@ import ( "strings" ) -// RootPath can be joined with a relative path and ensures that +// WorkspaceRootPath can be joined with a relative path and ensures that // the returned path is always a strict child of the root path. -type RootPath struct { +type WorkspaceRootPath struct { rootPath string } -// NewRootPath constructs and returns [RootPath]. +// NewWorkspaceRootPath constructs and returns [WorkspaceRootPath]. // The named path is cleaned on construction. -func NewRootPath(name string) RootPath { - return RootPath{ +func NewWorkspaceRootPath(name string) WorkspaceRootPath { + return WorkspaceRootPath{ rootPath: path.Clean(name), } } // Join returns the specified path name joined to the root. // It returns an error if the resulting path is not a strict child of the root path. -func (p *RootPath) Join(name string) (string, error) { +func (p *WorkspaceRootPath) Join(name string) (string, error) { absPath := path.Join(p.rootPath, name) // Don't allow escaping the specified root using relative paths. 
diff --git a/libs/filer/root_path_test.go b/libs/filer/workspace_root_path_test.go similarity index 98% rename from libs/filer/root_path_test.go rename to libs/filer/workspace_root_path_test.go index 965842d0..73746d27 100644 --- a/libs/filer/root_path_test.go +++ b/libs/filer/workspace_root_path_test.go @@ -9,7 +9,7 @@ import ( func testRootPath(t *testing.T, uncleanRoot string) { cleanRoot := path.Clean(uncleanRoot) - rp := NewRootPath(uncleanRoot) + rp := NewWorkspaceRootPath(uncleanRoot) remotePath, err := rp.Join("a/b/c") assert.NoError(t, err)