diff --git a/cmd/workspace/workspace/import_dir.go b/cmd/workspace/workspace/import_dir.go new file mode 100644 index 00000000..385deca4 --- /dev/null +++ b/cmd/workspace/workspace/import_dir.go @@ -0,0 +1,84 @@ +package workspace + +import ( + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/sync" + "github.com/databricks/databricks-sdk-go" + "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" +) + +// TODO: add some comments here +var importDirCmd = &cobra.Command{ + Use: "import_dir SOURCE_PATH TARGET_PATH", + Short: `Recursively imports a directory from local to the Databricks workspace.`, + Long: ` + Imports directory to the workspace. + + This command respects your git ignore configuration. Notebooks with extensions + .scala, .py, .sql, .r, .R, .ipynb are stripped of their extensions. +`, + + Annotations: map[string]string{}, + PreRunE: root.MustWorkspaceClient, + Args: cobra.ExactArgs(2), + RunE: func(cmd *cobra.Command, args []string) (err error) { + ctx := cmd.Context() + sourcePath := args[0] + targetPath := args[1] + + // Initialize syncer to do a full sync with the correct from source to target. + // This will upload the local files + opts := sync.SyncOptions{ + LocalPath: sourcePath, + RemotePath: targetPath, + Full: true, + WorkspaceClient: databricks.Must(databricks.NewWorkspaceClient()), + DisallowOverwrites: !importDirOverwrite, + } + s, err := sync.New(ctx, opts) + if err != nil { + return err + } + + // Initialize error wait group, and spawn the progress event emitter inside + // the error wait group + group, ctx := errgroup.WithContext(ctx) + group.Go( + func() error { + return renderSyncEvents(ctx, s.Events(), s) + }) + + // Start Uploading local files + cmdio.Render(ctx, newImportStartedEvent(sourcePath, targetPath)) + err = s.RunOnce(ctx) + if err != nil { + return err + } + // Upload completed, close the syncer + s.Close() + + // Wait for any inflight progress events to be emitted + if err := group.Wait(); err != nil { + return err + } + + // Render import completetion event + cmdio.Render(ctx, newImportCompleteEvent(sourcePath, targetPath)) + return nil + }, +} + +var importDirOverwrite bool + +func init() { + importDirCmd.Annotations["template"] = cmdio.Heredoc(` + {{if eq .Type "IMPORT_STARTED"}}Import started + {{else if eq .Type "UPLOAD_COMPLETE"}}Uploaded {{.SourcePath}} -> {{.TargetPath}} + {{else if eq .Type "IMPORT_COMPLETE"}}Import completed + {{end}} + `) + importDirCmd.Flags().BoolVar(&importDirOverwrite, "overwrite", false, "Overwrite if file already exists in the workspace") + Cmd.AddCommand(importDirCmd) +} diff --git a/cmd/workspace/workspace/import_events.go b/cmd/workspace/workspace/import_events.go new file mode 100644 index 00000000..7627691c --- /dev/null +++ b/cmd/workspace/workspace/import_events.go @@ -0,0 +1,67 @@ +package workspace + +import ( + "context" + + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/sync" +) + +type fileIOEvent struct { + SourcePath string `json:"source_path,omitempty"` + TargetPath string `json:"target_path,omitempty"` + Type string `json:"type"` +} + +func newImportStartedEvent(sourcePath, targetPath string) fileIOEvent { + return fileIOEvent{ + SourcePath: sourcePath, + TargetPath: targetPath, + Type: "IMPORT_STARTED", + } +} + +func newImportCompleteEvent(sourcePath, targetPath string) fileIOEvent { + return fileIOEvent{ + Type: "IMPORT_COMPLETE", + } +} + +func newUploadCompleteEvent(sourcePath, targetPath string) fileIOEvent { + return fileIOEvent{ + SourcePath: sourcePath, + TargetPath: targetPath, + Type: "UPLOAD_COMPLETE", + } +} + +func renderSyncEvents(ctx context.Context, eventChannel <-chan sync.Event, syncer *sync.Sync) error { + for { + select { + case <-ctx.Done(): + return nil + case e, ok := <-eventChannel: + if !ok { + return nil + } + + // We parse progress events from the sync to track when file uploads + // are complete and emit the corresponding events + if e.String() != "" && e.Type() == sync.EventTypeProgress { + progressEvent := e.(*sync.EventSyncProgress) + if progressEvent.Progress < 1 { + return nil + } + // TODO: test this works with windows paths + remotePath, err := syncer.RemotePath(progressEvent.Path) + if err != nil { + return err + } + err = cmdio.Render(ctx, newUploadCompleteEvent(progressEvent.Path, remotePath)) + if err != nil { + return err + } + } + } + } +} diff --git a/libs/cmdio/render.go b/libs/cmdio/render.go index 8aff2b8d..c5abe2cb 100644 --- a/libs/cmdio/render.go +++ b/libs/cmdio/render.go @@ -49,6 +49,10 @@ func renderJson(w io.Writer, v any) error { return err } _, err = w.Write(pretty) + if err != nil { + return err + } + _, err = w.Write([]byte("\n")) return err } diff --git a/libs/sync/event.go b/libs/sync/event.go index 8e5c0efa..dcba2232 100644 --- a/libs/sync/event.go +++ b/libs/sync/event.go @@ -24,6 +24,7 @@ const ( type Event interface { fmt.Stringer + Type() EventType } type EventBase struct { @@ -73,6 +74,10 @@ func (e *EventStart) String() string { return fmt.Sprintf("Action: %s", e.EventChanges.String()) } +func (e *EventStart) Type() EventType { + return EventTypeStart +} + func newEventStart(seq int, put []string, delete []string) Event { return &EventStart{ EventBase: newEventBase(seq, EventTypeStart), @@ -106,6 +111,10 @@ func (e *EventSyncProgress) String() string { } } +func (e *EventSyncProgress) Type() EventType { + return EventTypeProgress +} + func newEventProgress(seq int, action EventAction, path string, progress float32) Event { return &EventSyncProgress{ EventBase: newEventBase(seq, EventTypeProgress), @@ -133,6 +142,10 @@ func (e *EventSyncComplete) String() string { return "Complete" } +func (e *EventSyncComplete) Type() EventType { + return EventTypeComplete +} + func newEventComplete(seq int, put []string, delete []string) Event { return &EventSyncComplete{ EventBase: newEventBase(seq, EventTypeComplete), diff --git a/libs/sync/repofiles/repofiles.go b/libs/sync/repofiles/repofiles.go index 8fcabc11..97710217 100644 --- a/libs/sync/repofiles/repofiles.go +++ b/libs/sync/repofiles/repofiles.go @@ -17,23 +17,30 @@ import ( "github.com/databricks/databricks-sdk-go/service/workspace" ) +type RepoFileOptions struct { + OverwriteIfExists bool +} + // RepoFiles wraps reading and writing into a remote repo with safeguards to prevent // accidental deletion of repos and more robust methods to overwrite workspace files type RepoFiles struct { + *RepoFileOptions + repoRoot string localRoot string workspaceClient *databricks.WorkspaceClient } -func Create(repoRoot, localRoot string, workspaceClient *databricks.WorkspaceClient) *RepoFiles { +func Create(repoRoot, localRoot string, workspaceClient *databricks.WorkspaceClient, opts *RepoFileOptions) *RepoFiles { return &RepoFiles{ repoRoot: repoRoot, localRoot: localRoot, workspaceClient: workspaceClient, + RepoFileOptions: opts, } } -func (r *RepoFiles) remotePath(relativePath string) (string, error) { +func (r *RepoFiles) RemotePath(relativePath string) (string, error) { fullPath := path.Join(r.repoRoot, relativePath) cleanFullPath := path.Clean(fullPath) if !strings.HasPrefix(cleanFullPath, r.repoRoot) { @@ -58,12 +65,12 @@ func (r *RepoFiles) writeRemote(ctx context.Context, relativePath string, conten if err != nil { return err } - remotePath, err := r.remotePath(relativePath) + remotePath, err := r.RemotePath(relativePath) if err != nil { return err } escapedPath := url.PathEscape(strings.TrimLeft(remotePath, "/")) - apiPath := fmt.Sprintf("/api/2.0/workspace-files/import-file/%s?overwrite=true", escapedPath) + apiPath := fmt.Sprintf("/api/2.0/workspace-files/import-file/%s?overwrite=%t", escapedPath, r.OverwriteIfExists) err = apiClient.Do(ctx, http.MethodPost, apiPath, content, nil) @@ -113,7 +120,7 @@ func (r *RepoFiles) writeRemote(ctx context.Context, relativePath string, conten } func (r *RepoFiles) deleteRemote(ctx context.Context, relativePath string) error { - remotePath, err := r.remotePath(relativePath) + remotePath, err := r.RemotePath(relativePath) if err != nil { return err } diff --git a/libs/sync/repofiles/repofiles_test.go b/libs/sync/repofiles/repofiles_test.go index 2a881d90..e71f26ab 100644 --- a/libs/sync/repofiles/repofiles_test.go +++ b/libs/sync/repofiles/repofiles_test.go @@ -10,68 +10,68 @@ import ( func TestRepoFilesRemotePath(t *testing.T) { repoRoot := "/Repos/doraemon/bar" - repoFiles := Create(repoRoot, "/doraemon/foo/bar", nil) + repoFiles := Create(repoRoot, "/doraemon/foo/bar", nil, nil) - remotePath, err := repoFiles.remotePath("a/b/c") + remotePath, err := repoFiles.RemotePath("a/b/c") assert.NoError(t, err) assert.Equal(t, repoRoot+"/a/b/c", remotePath) - remotePath, err = repoFiles.remotePath("a/b/../d") + remotePath, err = repoFiles.RemotePath("a/b/../d") assert.NoError(t, err) assert.Equal(t, repoRoot+"/a/d", remotePath) - remotePath, err = repoFiles.remotePath("a/../c") + remotePath, err = repoFiles.RemotePath("a/../c") assert.NoError(t, err) assert.Equal(t, repoRoot+"/c", remotePath) - remotePath, err = repoFiles.remotePath("a/b/c/.") + remotePath, err = repoFiles.RemotePath("a/b/c/.") assert.NoError(t, err) assert.Equal(t, repoRoot+"/a/b/c", remotePath) - remotePath, err = repoFiles.remotePath("a/b/c/d/./../../f/g") + remotePath, err = repoFiles.RemotePath("a/b/c/d/./../../f/g") assert.NoError(t, err) assert.Equal(t, repoRoot+"/a/b/f/g", remotePath) - _, err = repoFiles.remotePath("..") + _, err = repoFiles.RemotePath("..") assert.ErrorContains(t, err, `relative file path is not inside repo root: ..`) - _, err = repoFiles.remotePath("a/../..") + _, err = repoFiles.RemotePath("a/../..") assert.ErrorContains(t, err, `relative file path is not inside repo root: a/../..`) - _, err = repoFiles.remotePath("./../.") + _, err = repoFiles.RemotePath("./../.") assert.ErrorContains(t, err, `relative file path is not inside repo root: ./../.`) - _, err = repoFiles.remotePath("/./.././..") + _, err = repoFiles.RemotePath("/./.././..") assert.ErrorContains(t, err, `relative file path is not inside repo root: /./.././..`) - _, err = repoFiles.remotePath("./../.") + _, err = repoFiles.RemotePath("./../.") assert.ErrorContains(t, err, `relative file path is not inside repo root: ./../.`) - _, err = repoFiles.remotePath("./..") + _, err = repoFiles.RemotePath("./..") assert.ErrorContains(t, err, `relative file path is not inside repo root: ./..`) - _, err = repoFiles.remotePath("./../../..") + _, err = repoFiles.RemotePath("./../../..") assert.ErrorContains(t, err, `relative file path is not inside repo root: ./../../..`) - _, err = repoFiles.remotePath("./../a/./b../../..") + _, err = repoFiles.RemotePath("./../a/./b../../..") assert.ErrorContains(t, err, `relative file path is not inside repo root: ./../a/./b../../..`) - _, err = repoFiles.remotePath("../..") + _, err = repoFiles.RemotePath("../..") assert.ErrorContains(t, err, `relative file path is not inside repo root: ../..`) - _, err = repoFiles.remotePath(".//a/..//./b/..") + _, err = repoFiles.RemotePath(".//a/..//./b/..") assert.ErrorContains(t, err, `file path relative to repo root cannot be empty`) - _, err = repoFiles.remotePath("a/b/../..") + _, err = repoFiles.RemotePath("a/b/../..") assert.ErrorContains(t, err, "file path relative to repo root cannot be empty") - _, err = repoFiles.remotePath("") + _, err = repoFiles.RemotePath("") assert.ErrorContains(t, err, "file path relative to repo root cannot be empty") - _, err = repoFiles.remotePath(".") + _, err = repoFiles.RemotePath(".") assert.ErrorContains(t, err, "file path relative to repo root cannot be empty") - _, err = repoFiles.remotePath("/") + _, err = repoFiles.RemotePath("/") assert.ErrorContains(t, err, "file path relative to repo root cannot be empty") } @@ -81,7 +81,7 @@ func TestRepoReadLocal(t *testing.T) { err := os.WriteFile(helloPath, []byte("my name is doraemon :P"), os.ModePerm) assert.NoError(t, err) - repoFiles := Create("/Repos/doraemon/bar", tempDir, nil) + repoFiles := Create("/Repos/doraemon/bar", tempDir, nil, nil) bytes, err := repoFiles.readLocal("./a/../hello.txt") assert.NoError(t, err) assert.Equal(t, "my name is doraemon :P", string(bytes)) diff --git a/libs/sync/sync.go b/libs/sync/sync.go index 54d0624e..e19340fb 100644 --- a/libs/sync/sync.go +++ b/libs/sync/sync.go @@ -24,6 +24,10 @@ type SyncOptions struct { WorkspaceClient *databricks.WorkspaceClient Host string + + // If set, sync will not be able to overwrite any existing paths on the + // workspace file system. + DisallowOverwrites bool } type Sync struct { @@ -76,8 +80,9 @@ func New(ctx context.Context, opts SyncOptions) (*Sync, error) { return nil, fmt.Errorf("unable to load sync snapshot: %w", err) } } - - repoFiles := repofiles.Create(opts.RemotePath, opts.LocalPath, opts.WorkspaceClient) + repoFiles := repofiles.Create(opts.RemotePath, opts.LocalPath, opts.WorkspaceClient, &repofiles.RepoFileOptions{ + OverwriteIfExists: !opts.DisallowOverwrites, + }) return &Sync{ SyncOptions: &opts, @@ -125,6 +130,14 @@ func (s *Sync) notifyComplete(ctx context.Context, d diff) { s.seq++ } +func (s *Sync) RemotePath(localPath string) (string, error) { + relativePath, ok := s.snapshot.LocalToRemoteNames[localPath] + if !ok { + return "", fmt.Errorf("could not find remote path for %s", localPath) + } + return s.repoFiles.RemotePath(relativePath) +} + func (s *Sync) RunOnce(ctx context.Context) error { // tradeoff: doing portable monitoring only due to macOS max descriptor manual ulimit setting requirement // https://github.com/gorakhargosh/watchdog/blob/master/src/watchdog/observers/kqueue.py#L394-L418