Make local files default for fs commands (#506)

## Changes
<!-- Summary of your changes that are easy to understand -->

## Tests
<!-- How is this tested? -->
This commit is contained in:
shreyas-goenka 2023-06-23 16:07:09 +02:00 committed by GitHub
parent d0e9953ad9
commit 30efe91c6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 261 additions and 72 deletions

View File

@ -10,47 +10,38 @@ import (
"github.com/databricks/cli/libs/filer" "github.com/databricks/cli/libs/filer"
) )
type Scheme string
const (
DbfsScheme = Scheme("dbfs")
LocalScheme = Scheme("file")
NoScheme = Scheme("")
)
func filerForPath(ctx context.Context, fullPath string) (filer.Filer, string, error) { func filerForPath(ctx context.Context, fullPath string) (filer.Filer, string, error) {
parts := strings.SplitN(fullPath, ":/", 2) // Split path at : to detect any file schemes
parts := strings.SplitN(fullPath, ":", 2)
// If no scheme is specified, then local path
if len(parts) < 2 { if len(parts) < 2 {
return nil, "", fmt.Errorf(`no scheme specified for path %s. Please specify scheme "dbfs" or "file". Example: file:/foo/bar or file:/c:/foo/bar`, fullPath) f, err := filer.NewLocalClient("")
return f, fullPath, err
} }
scheme := Scheme(parts[0])
// On windows systems, paths start with a drive letter. If the scheme
// is a single letter and the OS is windows, then we conclude the path
// is meant to be a local path.
if runtime.GOOS == "windows" && len(parts[0]) == 1 {
f, err := filer.NewLocalClient("")
return f, fullPath, err
}
if parts[0] != "dbfs" {
return nil, "", fmt.Errorf("invalid scheme: %s", parts[0])
}
path := parts[1] path := parts[1]
switch scheme { w := root.WorkspaceClient(ctx)
case DbfsScheme:
w := root.WorkspaceClient(ctx)
// If the specified path has the "Volumes" prefix, use the Files API.
if strings.HasPrefix(path, "Volumes/") {
f, err := filer.NewFilesClient(w, "/")
return f, path, err
}
f, err := filer.NewDbfsClient(w, "/")
return f, path, err
case LocalScheme: // If the specified path has the "Volumes" prefix, use the Files API.
if runtime.GOOS == "windows" { if strings.HasPrefix(path, "/Volumes/") {
parts := strings.SplitN(path, ":", 2) f, err := filer.NewFilesClient(w, "/")
if len(parts) < 2 {
return nil, "", fmt.Errorf("no volume specfied for path: %s", path)
}
volume := parts[0] + ":"
relPath := parts[1]
f, err := filer.NewLocalClient(volume)
return f, relPath, err
}
f, err := filer.NewLocalClient("/")
return f, path, err return f, path, err
default:
return nil, "", fmt.Errorf(`unsupported scheme %s specified for path %s. Please specify scheme "dbfs" or "file". Example: file:/foo/bar or file:/c:/foo/bar`, scheme, fullPath)
} }
// The file is a dbfs file, and uses the DBFS APIs
f, err := filer.NewDbfsClient(w, "/")
return f, path, err
} }

View File

@ -5,22 +5,58 @@ import (
"runtime" "runtime"
"testing" "testing"
"github.com/databricks/cli/libs/filer"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestNotSpecifyingVolumeForWindowsPathErrors(t *testing.T) { func TestFilerForPathForLocalPaths(t *testing.T) {
tmpDir := t.TempDir()
ctx := context.Background()
f, path, err := filerForPath(ctx, tmpDir)
assert.NoError(t, err)
assert.Equal(t, tmpDir, path)
info, err := f.Stat(ctx, path)
require.NoError(t, err)
assert.True(t, info.IsDir())
}
func TestFilerForPathForInvalidScheme(t *testing.T) {
ctx := context.Background()
_, _, err := filerForPath(ctx, "dbf:/a")
assert.ErrorContains(t, err, "invalid scheme")
_, _, err = filerForPath(ctx, "foo:a")
assert.ErrorContains(t, err, "invalid scheme")
_, _, err = filerForPath(ctx, "file:/a")
assert.ErrorContains(t, err, "invalid scheme")
}
func testWindowsFilerForPath(t *testing.T, ctx context.Context, fullPath string) {
f, path, err := filerForPath(ctx, fullPath)
assert.NoError(t, err)
// Assert path remains unchanged
assert.Equal(t, path, fullPath)
// Assert local client is created
_, ok := f.(*filer.LocalClient)
assert.True(t, ok)
}
func TestFilerForWindowsLocalPaths(t *testing.T) {
if runtime.GOOS != "windows" { if runtime.GOOS != "windows" {
t.Skip() t.SkipNow()
} }
ctx := context.Background() ctx := context.Background()
pathWithVolume := `file:/c:/foo/bar` testWindowsFilerForPath(t, ctx, `c:\abc`)
pathWOVolume := `file:/uno/dos` testWindowsFilerForPath(t, ctx, `c:abc`)
testWindowsFilerForPath(t, ctx, `d:\abc`)
_, path, err := filerForPath(ctx, pathWithVolume) testWindowsFilerForPath(t, ctx, `d:\abc`)
assert.NoError(t, err) testWindowsFilerForPath(t, ctx, `f:\abc\ef`)
assert.Equal(t, `/foo/bar`, path)
_, _, err = filerForPath(ctx, pathWOVolume)
assert.Equal(t, "no volume specfied for path: uno/dos", err.Error())
} }

View File

@ -66,7 +66,7 @@ func setupLocalFiler(t *testing.T) (filer.Filer, string) {
f, err := filer.NewLocalClient(tmp) f, err := filer.NewLocalClient(tmp)
require.NoError(t, err) require.NoError(t, err)
return f, path.Join("file:/", filepath.ToSlash(tmp)) return f, path.Join(filepath.ToSlash(tmp))
} }
func setupDbfsFiler(t *testing.T) (filer.Filer, string) { func setupDbfsFiler(t *testing.T) (filer.Filer, string) {
@ -259,21 +259,14 @@ func TestAccFsCpErrorsWhenSourceIsDirWithoutRecursiveFlag(t *testing.T) {
tmpDir := temporaryDbfsDir(t, w) tmpDir := temporaryDbfsDir(t, w)
_, _, err = RequireErrorRun(t, "fs", "cp", "dbfs:"+tmpDir, "dbfs:/tmp") _, _, err = RequireErrorRun(t, "fs", "cp", "dbfs:"+tmpDir, "dbfs:/tmp")
assert.Equal(t, fmt.Sprintf("source path %s is a directory. Please specify the --recursive flag", strings.TrimPrefix(tmpDir, "/")), err.Error()) assert.Equal(t, fmt.Sprintf("source path %s is a directory. Please specify the --recursive flag", tmpDir), err.Error())
}
func TestAccFsCpErrorsOnNoScheme(t *testing.T) {
t.Log(GetEnvOrSkipTest(t, "CLOUD_ENV"))
_, _, err := RequireErrorRun(t, "fs", "cp", "/a", "/b")
assert.Equal(t, "no scheme specified for path /a. Please specify scheme \"dbfs\" or \"file\". Example: file:/foo/bar or file:/c:/foo/bar", err.Error())
} }
func TestAccFsCpErrorsOnInvalidScheme(t *testing.T) { func TestAccFsCpErrorsOnInvalidScheme(t *testing.T) {
t.Log(GetEnvOrSkipTest(t, "CLOUD_ENV")) t.Log(GetEnvOrSkipTest(t, "CLOUD_ENV"))
_, _, err := RequireErrorRun(t, "fs", "cp", "dbfs:/a", "https:/b") _, _, err := RequireErrorRun(t, "fs", "cp", "dbfs:/a", "https:/b")
assert.Equal(t, "unsupported scheme https specified for path https:/b. Please specify scheme \"dbfs\" or \"file\". Example: file:/foo/bar or file:/c:/foo/bar", err.Error()) assert.Equal(t, "invalid scheme: https", err.Error())
} }
func TestAccFsCpSourceIsDirectoryButTargetIsFile(t *testing.T) { func TestAccFsCpSourceIsDirectoryButTargetIsFile(t *testing.T) {

View File

@ -68,14 +68,14 @@ type DbfsClient struct {
workspaceClient *databricks.WorkspaceClient workspaceClient *databricks.WorkspaceClient
// File operations will be relative to this path. // File operations will be relative to this path.
root RootPath root WorkspaceRootPath
} }
func NewDbfsClient(w *databricks.WorkspaceClient, root string) (Filer, error) { func NewDbfsClient(w *databricks.WorkspaceClient, root string) (Filer, error) {
return &DbfsClient{ return &DbfsClient{
workspaceClient: w, workspaceClient: w,
root: NewRootPath(root), root: NewWorkspaceRootPath(root),
}, nil }, nil
} }

View File

@ -60,7 +60,7 @@ type FilesClient struct {
apiClient *client.DatabricksClient apiClient *client.DatabricksClient
// File operations will be relative to this path. // File operations will be relative to this path.
root RootPath root WorkspaceRootPath
} }
func filesNotImplementedError(fn string) error { func filesNotImplementedError(fn string) error {
@ -77,7 +77,7 @@ func NewFilesClient(w *databricks.WorkspaceClient, root string) (Filer, error) {
workspaceClient: w, workspaceClient: w,
apiClient: apiClient, apiClient: apiClient,
root: NewRootPath(root), root: NewWorkspaceRootPath(root),
}, nil }, nil
} }

View File

@ -13,12 +13,12 @@ import (
// LocalClient implements the [Filer] interface for the local filesystem. // LocalClient implements the [Filer] interface for the local filesystem.
type LocalClient struct { type LocalClient struct {
// File operations will be relative to this path. // File operations will be relative to this path.
root RootPath root localRootPath
} }
func NewLocalClient(root string) (Filer, error) { func NewLocalClient(root string) (Filer, error) {
return &LocalClient{ return &LocalClient{
root: NewRootPath(root), root: NewLocalRootPath(root),
}, nil }, nil
} }

View File

@ -0,0 +1,27 @@
package filer
import (
"fmt"
"path/filepath"
"strings"
)
// localRootPath is a root directory on the local filesystem that relative
// names can be joined onto without lexically escaping the root.
type localRootPath struct {
	// rootPath is the cleaned root, or "" when no root restriction applies.
	rootPath string
}

// NewLocalRootPath constructs a localRootPath for the given root.
//
// An empty root is kept as-is instead of being cleaned, because
// filepath.Clean("") returns "." and would silently pin the root to the
// current directory. With an empty root, Join only cleans the name and never
// reports an escape.
func NewLocalRootPath(root string) localRootPath {
	if root == "" {
		return localRootPath{""}
	}
	return localRootPath{filepath.Clean(root)}
}

// Join returns name joined onto the root, cleaned with filepath.Join.
//
// It returns an error if the resulting path is neither the root itself nor
// located below it. The check is purely lexical: it operates on the cleaned
// path strings and does not touch the filesystem.
func (rp *localRootPath) Join(name string) (string, error) {
	absPath := filepath.Join(rp.rootPath, name)

	// An empty root places no restriction on the result, and the root itself
	// is always a valid result.
	if rp.rootPath == "" || absPath == rp.rootPath {
		return absPath, nil
	}

	// Compare against the root followed by a separator so that a sibling that
	// merely shares the root's name as a string prefix ("/root/path" vs
	// "/root/pathology") is not mistaken for a child. Roots that already end
	// in a separator (e.g. "/" or `c:\`) must not get a second one appended.
	sep := string(filepath.Separator)
	prefix := rp.rootPath
	if !strings.HasSuffix(prefix, sep) {
		prefix += sep
	}
	if !strings.HasPrefix(absPath, prefix) {
		return "", fmt.Errorf("relative path escapes root: %s", name)
	}
	return absPath, nil
}

View File

@ -0,0 +1,142 @@
package filer
import (
"path/filepath"
"runtime"
"testing"
"github.com/stretchr/testify/assert"
)
// testUnixLocalRootPath exercises localRootPath.Join for a root that may be
// supplied in unclean form, using slash-separated relative names.
func testUnixLocalRootPath(t *testing.T, uncleanRoot string) {
	cleanRoot := filepath.Clean(uncleanRoot)
	rp := NewLocalRootPath(uncleanRoot)

	// Names that stay within the root, with the suffix expected after the
	// cleaned root ("" means the result is the root itself).
	contained := []struct {
		name   string
		suffix string
	}{
		{"a/b/c", "/a/b/c"},
		{"a/b/../d", "/a/d"},
		{"a/../c", "/c"},
		{"a/b/c/.", "/a/b/c"},
		{"a/b/c/d/./../../f/g", "/a/b/f/g"},
		{".//a/..//./b/..", ""},
		{"a/b/../..", ""},
		{"", ""},
		{".", ""},
		{"/", ""},
	}
	for _, tc := range contained {
		joined, err := rp.Join(tc.name)
		assert.NoError(t, err)
		assert.Equal(t, cleanRoot+tc.suffix, joined)
	}

	// Names that climb above the root; the error must echo the name verbatim.
	escaping := []string{
		"..",
		"a/../..",
		"./../.",
		"/./.././..",
		"./../.",
		"./..",
		"./../../..",
		"./../a/./b../../..",
		"../..",
	}
	for _, name := range escaping {
		_, err := rp.Join(name)
		assert.ErrorContains(t, err, `relative path escapes root: `+name)
	}
}
// TestUnixLocalRootPath runs the unix Join checks against several spellings
// of a root path. It is skipped on anything other than darwin and linux.
func TestUnixLocalRootPath(t *testing.T) {
	switch runtime.GOOS {
	case "darwin", "linux":
		// Supported platforms; fall out of the switch and run the checks.
	default:
		t.SkipNow()
	}

	roots := []string{
		"/some/root/path",
		"/some/root/path/",
		"/some/root/path/.",
		"/some/root/../path/",
	}
	for _, root := range roots {
		testUnixLocalRootPath(t, root)
	}
}
// testWindowsLocalRootPath exercises localRootPath.Join for a root that may
// be supplied in unclean form, using backslash-separated relative names.
func testWindowsLocalRootPath(t *testing.T, uncleanRoot string) {
	cleanRoot := filepath.Clean(uncleanRoot)
	rp := NewLocalRootPath(uncleanRoot)

	// Names that stay within the root, with the suffix expected after the
	// cleaned root ("" means the result is the root itself).
	contained := []struct {
		name   string
		suffix string
	}{
		{`a\b\c`, `\a\b\c`},
		{`a\b\..\d`, `\a\d`},
		{`a\..\c`, `\c`},
		{`a\b\c\.`, `\a\b\c`},
		{`a\b\..\..`, ``},
		{"", ""},
		{".", ""},
	}
	for _, tc := range contained {
		joined, err := rp.Join(tc.name)
		assert.NoError(t, err)
		assert.Equal(t, cleanRoot+tc.suffix, joined)
	}

	// Names that climb above the root must be rejected.
	for _, name := range []string{"..", `a\..\..`} {
		_, err := rp.Join(name)
		assert.ErrorContains(t, err, `relative path escapes root`)
	}
}
// TestWindowsLocalRootPath runs the windows Join checks against several
// spellings of a root path. It is skipped on every OS other than windows.
func TestWindowsLocalRootPath(t *testing.T) {
	if runtime.GOOS != "windows" {
		t.SkipNow()
	}

	roots := []string{
		`c:\some\root\path`,
		`c:\some\root\path\`,
		`c:\some\root\path\.`,
		`C:\some\root\..\path\`,
	}
	for _, root := range roots {
		testWindowsLocalRootPath(t, root)
	}
}

View File

@ -78,7 +78,7 @@ type WorkspaceFilesClient struct {
apiClient *client.DatabricksClient apiClient *client.DatabricksClient
// File operations will be relative to this path. // File operations will be relative to this path.
root RootPath root WorkspaceRootPath
} }
func NewWorkspaceFilesClient(w *databricks.WorkspaceClient, root string) (Filer, error) { func NewWorkspaceFilesClient(w *databricks.WorkspaceClient, root string) (Filer, error) {
@ -91,7 +91,7 @@ func NewWorkspaceFilesClient(w *databricks.WorkspaceClient, root string) (Filer,
workspaceClient: w, workspaceClient: w,
apiClient: apiClient, apiClient: apiClient,
root: NewRootPath(root), root: NewWorkspaceRootPath(root),
}, nil }, nil
} }

View File

@ -6,23 +6,23 @@ import (
"strings" "strings"
) )
// RootPath can be joined with a relative path and ensures that // WorkspaceRootPath can be joined with a relative path and ensures that
// the returned path is always a strict child of the root path. // the returned path is always a strict child of the root path.
type RootPath struct { type WorkspaceRootPath struct {
rootPath string rootPath string
} }
// NewRootPath constructs and returns [RootPath]. // NewWorkspaceRootPath constructs and returns [WorkspaceRootPath].
// The named path is cleaned on construction. // The named path is cleaned on construction.
func NewRootPath(name string) RootPath { func NewWorkspaceRootPath(name string) WorkspaceRootPath {
return RootPath{ return WorkspaceRootPath{
rootPath: path.Clean(name), rootPath: path.Clean(name),
} }
} }
// Join returns the specified path name joined to the root. // Join returns the specified path name joined to the root.
// It returns an error if the resulting path is not a strict child of the root path. // It returns an error if the resulting path is not a strict child of the root path.
func (p *RootPath) Join(name string) (string, error) { func (p *WorkspaceRootPath) Join(name string) (string, error) {
absPath := path.Join(p.rootPath, name) absPath := path.Join(p.rootPath, name)
// Don't allow escaping the specified root using relative paths. // Don't allow escaping the specified root using relative paths.

View File

@ -9,7 +9,7 @@ import (
func testRootPath(t *testing.T, uncleanRoot string) { func testRootPath(t *testing.T, uncleanRoot string) {
cleanRoot := path.Clean(uncleanRoot) cleanRoot := path.Clean(uncleanRoot)
rp := NewRootPath(uncleanRoot) rp := NewWorkspaceRootPath(uncleanRoot)
remotePath, err := rp.Join("a/b/c") remotePath, err := rp.Join("a/b/c")
assert.NoError(t, err) assert.NoError(t, err)