diff --git a/cmd/sync/sync.go b/cmd/sync/sync.go index 62d834e3..ea268be1 100644 --- a/cmd/sync/sync.go +++ b/cmd/sync/sync.go @@ -1,95 +1,17 @@ package sync import ( - "context" "fmt" "log" - "strings" "time" "github.com/databricks/bricks/cmd/root" "github.com/databricks/bricks/git" "github.com/databricks/bricks/libs/sync" "github.com/databricks/bricks/project" - "github.com/databricks/databricks-sdk-go" - "github.com/databricks/databricks-sdk-go/apierr" - "github.com/databricks/databricks-sdk-go/service/scim" - "github.com/databricks/databricks-sdk-go/service/workspace" "github.com/spf13/cobra" ) -func matchesBasePaths(me *scim.User, path string) error { - basePaths := []string{ - fmt.Sprintf("/Users/%s/", me.UserName), - fmt.Sprintf("/Repos/%s/", me.UserName), - } - basePathMatch := false - for _, basePath := range basePaths { - if strings.HasPrefix(path, basePath) { - basePathMatch = true - break - } - } - if !basePathMatch { - return fmt.Errorf("path must be nested under %s or %s", basePaths[0], basePaths[1]) - } - return nil -} - -// ensureRemotePathIsUsable checks if the specified path is nested under -// expected base paths and if it is a directory or repository. -func ensureRemotePathIsUsable(ctx context.Context, wsc *databricks.WorkspaceClient, me *scim.User, path string) error { - err := matchesBasePaths(me, path) - if err != nil { - return err - } - - // Ensure that the remote path exits. - // If it is a repo, it has to exist. - // If it is a workspace path, it may not exist. - info, err := wsc.Workspace.GetStatusByPath(ctx, path) - if err != nil { - // We only deal with 404s below. - if !apierr.IsMissing(err) { - return err - } - - switch { - case strings.HasPrefix(path, "/Repos/"): - return fmt.Errorf("%s does not exist; please create it first", path) - case strings.HasPrefix(path, "/Users/"): - // The workspace path doesn't exist. Create it and try again. - err = wsc.Workspace.MkdirsByPath(ctx, path) - if err != nil { - return fmt.Errorf("unable to create directory at %s: %w", path, err) - } - info, err = wsc.Workspace.GetStatusByPath(ctx, path) - if err != nil { - return err - } - default: - return err - } - } - - log.Printf( - "[DEBUG] Path %s has type %s (ID: %d)", - info.Path, - strings.ToLower(info.ObjectType.String()), - info.ObjectId, - ) - - // We expect the object at path to be a directory or a repo. - switch info.ObjectType { - case workspace.ObjectTypeDirectory: - return nil - case workspace.ObjectTypeRepo: - return nil - } - - return fmt.Errorf("%s points to a %s", path, strings.ToLower(info.ObjectType.String())) -} - // syncCmd represents the sync command var syncCmd = &cobra.Command{ Use: "sync", @@ -115,10 +37,6 @@ var syncCmd = &cobra.Command{ } log.Printf("[INFO] Remote file sync location: %v", *remotePath) - err = ensureRemotePathIsUsable(ctx, wsc, me, *remotePath) - if err != nil { - return err - } cacheDir, err := prj.CacheDir() if err != nil { @@ -134,7 +52,7 @@ var syncCmd = &cobra.Command{ WorkspaceClient: wsc, } - s, err := sync.New(opts) + s, err := sync.New(ctx, opts) if err != nil { return err } diff --git a/libs/sync/path.go b/libs/sync/path.go new file mode 100644 index 00000000..6a87316f --- /dev/null +++ b/libs/sync/path.go @@ -0,0 +1,109 @@ +package sync + +import ( + "context" + "fmt" + "log" + "path" + "strings" + + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/apierr" + "github.com/databricks/databricks-sdk-go/service/scim" + "github.com/databricks/databricks-sdk-go/service/workspace" +) + +// Return if the child path is nested under the parent path. +func isPathNestedUnder(child, parent string) bool { + child = path.Clean(child) + parent = path.Clean(parent) + + // Traverse up the tree as long as "child" is contained in "parent". + for len(child) > len(parent) && strings.HasPrefix(child, parent) { + child = path.Dir(child) + if child == parent { + return true + } + } + return false +} + +// Check if the specified path is nested under one of the allowed base paths. +func checkPathNestedUnderBasePaths(me *scim.User, p string) error { + validBasePaths := []string{ + path.Clean(fmt.Sprintf("/Users/%s", me.UserName)), + path.Clean(fmt.Sprintf("/Repos/%s", me.UserName)), + } + + givenBasePath := path.Clean(p) + match := false + for _, basePath := range validBasePaths { + if isPathNestedUnder(givenBasePath, basePath) { + match = true + break + } + } + if !match { + return fmt.Errorf("path must be nested under %s", strings.Join(validBasePaths, " or ")) + } + return nil +} + +// ensureRemotePathIsUsable checks if the specified path is nested under +// expected base paths and if it is a directory or repository. +func ensureRemotePathIsUsable(ctx context.Context, wsc *databricks.WorkspaceClient, path string) error { + me, err := wsc.CurrentUser.Me(ctx) + if err != nil { + return err + } + + err = checkPathNestedUnderBasePaths(me, path) + if err != nil { + return err + } + + // Ensure that the remote path exists. + // If it is a repo, it has to exist. + // If it is a workspace path, it may not exist. + info, err := wsc.Workspace.GetStatusByPath(ctx, path) + if err != nil { + // We only deal with 404s below. + if !apierr.IsMissing(err) { + return err + } + + switch { + case strings.HasPrefix(path, "/Repos/"): + return fmt.Errorf("%s does not exist; please create it first", path) + case strings.HasPrefix(path, "/Users/"): + // The workspace path doesn't exist. Create it and try again. + err = wsc.Workspace.MkdirsByPath(ctx, path) + if err != nil { + return fmt.Errorf("unable to create directory at %s: %w", path, err) + } + info, err = wsc.Workspace.GetStatusByPath(ctx, path) + if err != nil { + return err + } + default: + return err + } + } + + log.Printf( + "[DEBUG] Path %s has type %s (ID: %d)", + info.Path, + strings.ToLower(info.ObjectType.String()), + info.ObjectId, + ) + + // We expect the object at path to be a directory or a repo. + switch info.ObjectType { + case workspace.ObjectTypeDirectory: + return nil + case workspace.ObjectTypeRepo: + return nil + } + + return fmt.Errorf("%s points to a %s", path, strings.ToLower(info.ObjectType.String())) +} diff --git a/libs/sync/path_test.go b/libs/sync/path_test.go new file mode 100644 index 00000000..0671b634 --- /dev/null +++ b/libs/sync/path_test.go @@ -0,0 +1,39 @@ +package sync + +import ( + "testing" + + "github.com/databricks/databricks-sdk-go/service/scim" + "github.com/stretchr/testify/assert" +) + +func TestPathNestedUnderBasePaths(t *testing.T) { + me := scim.User{ + UserName: "jane@doe.com", + } + + // Not nested under allowed base paths. + assert.Error(t, checkPathNestedUnderBasePaths(&me, "/Repos/jane@doe.com")) + assert.Error(t, checkPathNestedUnderBasePaths(&me, "/Repos/jane@doe.com/.")) + assert.Error(t, checkPathNestedUnderBasePaths(&me, "/Repos/jane@doe.com/..")) + assert.Error(t, checkPathNestedUnderBasePaths(&me, "/Repos/john@doe.com")) + assert.Error(t, checkPathNestedUnderBasePaths(&me, "/Repos/jane@doe.comsuffix/foo")) + assert.Error(t, checkPathNestedUnderBasePaths(&me, "/Repos/")) + assert.Error(t, checkPathNestedUnderBasePaths(&me, "/Repos")) + assert.Error(t, checkPathNestedUnderBasePaths(&me, "/Users/jane@doe.com")) + assert.Error(t, checkPathNestedUnderBasePaths(&me, "/Users/jane@doe.com/.")) + assert.Error(t, checkPathNestedUnderBasePaths(&me, "/Users/jane@doe.com/..")) + assert.Error(t, checkPathNestedUnderBasePaths(&me, "/Users/john@doe.com")) + assert.Error(t, checkPathNestedUnderBasePaths(&me, "/Users/jane@doe.comsuffix/foo")) + assert.Error(t, checkPathNestedUnderBasePaths(&me, "/Users/")) + assert.Error(t, checkPathNestedUnderBasePaths(&me, "/Users")) + assert.Error(t, checkPathNestedUnderBasePaths(&me, "/")) + + // Nested under allowed base paths. + assert.NoError(t, checkPathNestedUnderBasePaths(&me, "/Repos/jane@doe.com/foo")) + assert.NoError(t, checkPathNestedUnderBasePaths(&me, "/Repos/jane@doe.com/./foo")) + assert.NoError(t, checkPathNestedUnderBasePaths(&me, "/Repos/jane@doe.com/foo/bar/qux")) + assert.NoError(t, checkPathNestedUnderBasePaths(&me, "/Users/jane@doe.com/foo")) + assert.NoError(t, checkPathNestedUnderBasePaths(&me, "/Users/jane@doe.com/./foo")) + assert.NoError(t, checkPathNestedUnderBasePaths(&me, "/Users/jane@doe.com/foo/bar/qux")) +} diff --git a/libs/sync/sync.go b/libs/sync/sync.go index 55055f8a..ec67c6cc 100644 --- a/libs/sync/sync.go +++ b/libs/sync/sync.go @@ -32,13 +32,19 @@ type Sync struct { } // New initializes and returns a new [Sync] instance. -func New(opts SyncOptions) (*Sync, error) { +func New(ctx context.Context, opts SyncOptions) (*Sync, error) { fileSet := git.NewFileSet(opts.LocalPath) err := fileSet.EnsureValidGitIgnoreExists() if err != nil { return nil, err } + // Verify that the remote path we're about to synchronize to is valid and allowed. + err = ensureRemotePathIsUsable(ctx, opts.WorkspaceClient, opts.RemotePath) + if err != nil { + return nil, err + } + // TODO: The host may be late-initialized in certain Azure setups where we // specify the workspace by its resource ID. tracked in: https://databricks.atlassian.net/browse/DECO-194 opts.Host = opts.WorkspaceClient.Config.Host