Add workspace import_dir command to the CLI

This commit is contained in:
Shreyas Goenka 2023-06-01 11:23:31 +02:00
parent 92cb52041d
commit 9f4bf1261b
No known key found for this signature in database
GPG Key ID: 92A07DF49CCB0622
7 changed files with 216 additions and 28 deletions

View File

@ -0,0 +1,84 @@
package workspace
import (
"github.com/databricks/cli/cmd/root"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/sync"
"github.com/databricks/databricks-sdk-go"
"github.com/spf13/cobra"
"golang.org/x/sync/errgroup"
)
// importDirCmd recursively imports a local directory into the Databricks
// workspace. It performs a full, one-shot sync from SOURCE_PATH (local) to
// TARGET_PATH (workspace), rendering a progress event per uploaded file.
var importDirCmd = &cobra.Command{
	Use:   "import_dir SOURCE_PATH TARGET_PATH",
	Short: `Recursively imports a directory from local to the Databricks workspace.`,
	Long: `
Imports directory to the workspace.
This command respects your git ignore configuration. Notebooks with extensions
.scala, .py, .sql, .r, .R, .ipynb are stripped of their extensions.
`,
	Annotations: map[string]string{},
	PreRunE:     root.MustWorkspaceClient,
	Args:        cobra.ExactArgs(2),
	RunE: func(cmd *cobra.Command, args []string) (err error) {
		ctx := cmd.Context()
		sourcePath := args[0]
		targetPath := args[1]

		// Initialize the syncer to do a full sync from source to target.
		// This will upload the local files.
		opts := sync.SyncOptions{
			LocalPath:          sourcePath,
			RemotePath:         targetPath,
			Full:               true,
			WorkspaceClient:    databricks.Must(databricks.NewWorkspaceClient()),
			DisallowOverwrites: !importDirOverwrite,
		}
		s, err := sync.New(ctx, opts)
		if err != nil {
			return err
		}

		// Initialize error wait group, and spawn the progress event emitter
		// inside the error wait group so rendering errors surface via Wait.
		group, ctx := errgroup.WithContext(ctx)
		group.Go(
			func() error {
				return renderSyncEvents(ctx, s.Events(), s)
			})

		// Start uploading local files. Render errors are propagated instead of
		// being silently dropped, consistent with renderSyncEvents.
		if err := cmdio.Render(ctx, newImportStartedEvent(sourcePath, targetPath)); err != nil {
			return err
		}
		if err := s.RunOnce(ctx); err != nil {
			return err
		}

		// Upload completed, close the syncer.
		s.Close()

		// Wait for any inflight progress events to be emitted.
		if err := group.Wait(); err != nil {
			return err
		}

		// Render import completion event.
		return cmdio.Render(ctx, newImportCompleteEvent(sourcePath, targetPath))
	},
}
// importDirOverwrite is bound to the --overwrite flag; when set, files that
// already exist in the workspace may be overwritten by the import.
var importDirOverwrite bool

func init() {
	// Register the text template used by cmdio to render the fileIOEvent
	// values emitted while importing (see events in this package).
	importDirCmd.Annotations["template"] = cmdio.Heredoc(`
{{if eq .Type "IMPORT_STARTED"}}Import started
{{else if eq .Type "UPLOAD_COMPLETE"}}Uploaded {{.SourcePath}} -> {{.TargetPath}}
{{else if eq .Type "IMPORT_COMPLETE"}}Import completed
{{end}}
`)
	importDirCmd.Flags().BoolVar(&importDirOverwrite, "overwrite", false, "Overwrite if file already exists in the workspace")
	Cmd.AddCommand(importDirCmd)
}

View File

@ -0,0 +1,67 @@
package workspace
import (
"context"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/sync"
)
// fileIOEvent is the progress event rendered by the import_dir command,
// either as JSON or through the command's registered text template.
type fileIOEvent struct {
	// Local path the event refers to; omitted from JSON when empty.
	SourcePath string `json:"source_path,omitempty"`
	// Workspace path the event refers to; omitted from JSON when empty.
	TargetPath string `json:"target_path,omitempty"`
	// One of "IMPORT_STARTED", "UPLOAD_COMPLETE" or "IMPORT_COMPLETE".
	Type string `json:"type"`
}
// newImportStartedEvent builds the event emitted once when the import of
// sourcePath into targetPath begins.
func newImportStartedEvent(sourcePath, targetPath string) fileIOEvent {
	event := fileIOEvent{Type: "IMPORT_STARTED"}
	event.SourcePath = sourcePath
	event.TargetPath = targetPath
	return event
}
// newImportCompleteEvent builds the event emitted once the import of
// sourcePath into targetPath has finished.
//
// Fix: the original ignored both parameters, leaving SourcePath/TargetPath
// empty — inconsistent with the other constructors and with the struct's
// omitempty JSON contract. Populate them like newImportStartedEvent does.
func newImportCompleteEvent(sourcePath, targetPath string) fileIOEvent {
	return fileIOEvent{
		SourcePath: sourcePath,
		TargetPath: targetPath,
		Type:       "IMPORT_COMPLETE",
	}
}
// newUploadCompleteEvent builds the event emitted after a single file at
// sourcePath has been uploaded to targetPath in the workspace.
func newUploadCompleteEvent(sourcePath, targetPath string) fileIOEvent {
	return fileIOEvent{
		Type:       "UPLOAD_COMPLETE",
		SourcePath: sourcePath,
		TargetPath: targetPath,
	}
}
// renderSyncEvents consumes events from the syncer's event channel and renders
// an UPLOAD_COMPLETE event for every file whose upload has finished. It returns
// nil when the channel is closed or the context is cancelled, and an error if
// resolving a remote path or rendering fails.
func renderSyncEvents(ctx context.Context, eventChannel <-chan sync.Event, syncer *sync.Sync) error {
	for {
		select {
		case <-ctx.Done():
			return nil
		case e, ok := <-eventChannel:
			if !ok {
				// Channel closed; no more events will arrive.
				return nil
			}
			// We parse progress events from the sync to track when file uploads
			// are complete and emit the corresponding events.
			if e.String() == "" || e.Type() != sync.EventTypeProgress {
				continue
			}
			progressEvent := e.(*sync.EventSyncProgress)
			// Fix: the original returned nil here, which terminated the
			// renderer on the first partial-progress event and dropped all
			// subsequent events. Skip this event and keep consuming instead.
			if progressEvent.Progress < 1 {
				continue
			}
			// TODO: test this works with windows paths
			remotePath, err := syncer.RemotePath(progressEvent.Path)
			if err != nil {
				return err
			}
			if err := cmdio.Render(ctx, newUploadCompleteEvent(progressEvent.Path, remotePath)); err != nil {
				return err
			}
		}
	}
}

View File

@ -49,6 +49,10 @@ func renderJson(w io.Writer, v any) error {
return err
}
_, err = w.Write(pretty)
if err != nil {
return err
}
_, err = w.Write([]byte("\n"))
return err
}

View File

@ -24,6 +24,7 @@ const (
// Event is a single sync lifecycle event. Implementations provide a
// human-readable form via fmt.Stringer and report their EventType.
type Event interface {
	fmt.Stringer
	Type() EventType
}
type EventBase struct {
@ -73,6 +74,10 @@ func (e *EventStart) String() string {
return fmt.Sprintf("Action: %s", e.EventChanges.String())
}
// Type reports this event's type: EventTypeStart.
func (e *EventStart) Type() EventType {
	return EventTypeStart
}
func newEventStart(seq int, put []string, delete []string) Event {
return &EventStart{
EventBase: newEventBase(seq, EventTypeStart),
@ -106,6 +111,10 @@ func (e *EventSyncProgress) String() string {
}
}
// Type reports this event's type: EventTypeProgress.
func (e *EventSyncProgress) Type() EventType {
	return EventTypeProgress
}
func newEventProgress(seq int, action EventAction, path string, progress float32) Event {
return &EventSyncProgress{
EventBase: newEventBase(seq, EventTypeProgress),
@ -133,6 +142,10 @@ func (e *EventSyncComplete) String() string {
return "Complete"
}
// Type reports this event's type: EventTypeComplete.
func (e *EventSyncComplete) Type() EventType {
	return EventTypeComplete
}
func newEventComplete(seq int, put []string, delete []string) Event {
return &EventSyncComplete{
EventBase: newEventBase(seq, EventTypeComplete),

View File

@ -17,23 +17,30 @@ import (
"github.com/databricks/databricks-sdk-go/service/workspace"
)
// RepoFileOptions holds optional behavior flags for RepoFiles.
type RepoFileOptions struct {
	// OverwriteIfExists allows remote writes to replace files that already
	// exist in the workspace (passed as overwrite=<bool> to the import API).
	OverwriteIfExists bool
}
// RepoFiles wraps reading and writing into a remote repo with safeguards to prevent
// accidental deletion of repos and more robust methods to overwrite workspace files
type RepoFiles struct {
	*RepoFileOptions

	// Workspace path of the repo root (e.g. /Repos/<user>/<repo>).
	repoRoot string
	// Local filesystem root the repo contents are read from.
	localRoot string
	// Client used for workspace API calls.
	workspaceClient *databricks.WorkspaceClient
}
func Create(repoRoot, localRoot string, workspaceClient *databricks.WorkspaceClient) *RepoFiles {
func Create(repoRoot, localRoot string, workspaceClient *databricks.WorkspaceClient, opts *RepoFileOptions) *RepoFiles {
return &RepoFiles{
repoRoot: repoRoot,
localRoot: localRoot,
workspaceClient: workspaceClient,
RepoFileOptions: opts,
}
}
func (r *RepoFiles) remotePath(relativePath string) (string, error) {
func (r *RepoFiles) RemotePath(relativePath string) (string, error) {
fullPath := path.Join(r.repoRoot, relativePath)
cleanFullPath := path.Clean(fullPath)
if !strings.HasPrefix(cleanFullPath, r.repoRoot) {
@ -58,12 +65,12 @@ func (r *RepoFiles) writeRemote(ctx context.Context, relativePath string, conten
if err != nil {
return err
}
remotePath, err := r.remotePath(relativePath)
remotePath, err := r.RemotePath(relativePath)
if err != nil {
return err
}
escapedPath := url.PathEscape(strings.TrimLeft(remotePath, "/"))
apiPath := fmt.Sprintf("/api/2.0/workspace-files/import-file/%s?overwrite=true", escapedPath)
apiPath := fmt.Sprintf("/api/2.0/workspace-files/import-file/%s?overwrite=%t", escapedPath, r.OverwriteIfExists)
err = apiClient.Do(ctx, http.MethodPost, apiPath, content, nil)
@ -113,7 +120,7 @@ func (r *RepoFiles) writeRemote(ctx context.Context, relativePath string, conten
}
func (r *RepoFiles) deleteRemote(ctx context.Context, relativePath string) error {
remotePath, err := r.remotePath(relativePath)
remotePath, err := r.RemotePath(relativePath)
if err != nil {
return err
}

View File

@ -10,68 +10,68 @@ import (
func TestRepoFilesRemotePath(t *testing.T) {
repoRoot := "/Repos/doraemon/bar"
repoFiles := Create(repoRoot, "/doraemon/foo/bar", nil)
repoFiles := Create(repoRoot, "/doraemon/foo/bar", nil, nil)
remotePath, err := repoFiles.remotePath("a/b/c")
remotePath, err := repoFiles.RemotePath("a/b/c")
assert.NoError(t, err)
assert.Equal(t, repoRoot+"/a/b/c", remotePath)
remotePath, err = repoFiles.remotePath("a/b/../d")
remotePath, err = repoFiles.RemotePath("a/b/../d")
assert.NoError(t, err)
assert.Equal(t, repoRoot+"/a/d", remotePath)
remotePath, err = repoFiles.remotePath("a/../c")
remotePath, err = repoFiles.RemotePath("a/../c")
assert.NoError(t, err)
assert.Equal(t, repoRoot+"/c", remotePath)
remotePath, err = repoFiles.remotePath("a/b/c/.")
remotePath, err = repoFiles.RemotePath("a/b/c/.")
assert.NoError(t, err)
assert.Equal(t, repoRoot+"/a/b/c", remotePath)
remotePath, err = repoFiles.remotePath("a/b/c/d/./../../f/g")
remotePath, err = repoFiles.RemotePath("a/b/c/d/./../../f/g")
assert.NoError(t, err)
assert.Equal(t, repoRoot+"/a/b/f/g", remotePath)
_, err = repoFiles.remotePath("..")
_, err = repoFiles.RemotePath("..")
assert.ErrorContains(t, err, `relative file path is not inside repo root: ..`)
_, err = repoFiles.remotePath("a/../..")
_, err = repoFiles.RemotePath("a/../..")
assert.ErrorContains(t, err, `relative file path is not inside repo root: a/../..`)
_, err = repoFiles.remotePath("./../.")
_, err = repoFiles.RemotePath("./../.")
assert.ErrorContains(t, err, `relative file path is not inside repo root: ./../.`)
_, err = repoFiles.remotePath("/./.././..")
_, err = repoFiles.RemotePath("/./.././..")
assert.ErrorContains(t, err, `relative file path is not inside repo root: /./.././..`)
_, err = repoFiles.remotePath("./../.")
_, err = repoFiles.RemotePath("./../.")
assert.ErrorContains(t, err, `relative file path is not inside repo root: ./../.`)
_, err = repoFiles.remotePath("./..")
_, err = repoFiles.RemotePath("./..")
assert.ErrorContains(t, err, `relative file path is not inside repo root: ./..`)
_, err = repoFiles.remotePath("./../../..")
_, err = repoFiles.RemotePath("./../../..")
assert.ErrorContains(t, err, `relative file path is not inside repo root: ./../../..`)
_, err = repoFiles.remotePath("./../a/./b../../..")
_, err = repoFiles.RemotePath("./../a/./b../../..")
assert.ErrorContains(t, err, `relative file path is not inside repo root: ./../a/./b../../..`)
_, err = repoFiles.remotePath("../..")
_, err = repoFiles.RemotePath("../..")
assert.ErrorContains(t, err, `relative file path is not inside repo root: ../..`)
_, err = repoFiles.remotePath(".//a/..//./b/..")
_, err = repoFiles.RemotePath(".//a/..//./b/..")
assert.ErrorContains(t, err, `file path relative to repo root cannot be empty`)
_, err = repoFiles.remotePath("a/b/../..")
_, err = repoFiles.RemotePath("a/b/../..")
assert.ErrorContains(t, err, "file path relative to repo root cannot be empty")
_, err = repoFiles.remotePath("")
_, err = repoFiles.RemotePath("")
assert.ErrorContains(t, err, "file path relative to repo root cannot be empty")
_, err = repoFiles.remotePath(".")
_, err = repoFiles.RemotePath(".")
assert.ErrorContains(t, err, "file path relative to repo root cannot be empty")
_, err = repoFiles.remotePath("/")
_, err = repoFiles.RemotePath("/")
assert.ErrorContains(t, err, "file path relative to repo root cannot be empty")
}
@ -81,7 +81,7 @@ func TestRepoReadLocal(t *testing.T) {
err := os.WriteFile(helloPath, []byte("my name is doraemon :P"), os.ModePerm)
assert.NoError(t, err)
repoFiles := Create("/Repos/doraemon/bar", tempDir, nil)
repoFiles := Create("/Repos/doraemon/bar", tempDir, nil, nil)
bytes, err := repoFiles.readLocal("./a/../hello.txt")
assert.NoError(t, err)
assert.Equal(t, "my name is doraemon :P", string(bytes))

View File

@ -24,6 +24,10 @@ type SyncOptions struct {
WorkspaceClient *databricks.WorkspaceClient
Host string
// If set, sync will not be able to overwrite any existing paths on the
// workspace file system.
DisallowOverwrites bool
}
type Sync struct {
@ -76,8 +80,9 @@ func New(ctx context.Context, opts SyncOptions) (*Sync, error) {
return nil, fmt.Errorf("unable to load sync snapshot: %w", err)
}
}
repoFiles := repofiles.Create(opts.RemotePath, opts.LocalPath, opts.WorkspaceClient)
repoFiles := repofiles.Create(opts.RemotePath, opts.LocalPath, opts.WorkspaceClient, &repofiles.RepoFileOptions{
OverwriteIfExists: !opts.DisallowOverwrites,
})
return &Sync{
SyncOptions: &opts,
@ -125,6 +130,14 @@ func (s *Sync) notifyComplete(ctx context.Context, d diff) {
s.seq++
}
// RemotePath resolves the remote workspace path that localPath was synced to,
// using the snapshot's local-to-remote name mapping. It returns an error if
// localPath is unknown to the snapshot.
func (s *Sync) RemotePath(localPath string) (string, error) {
	if relativePath, ok := s.snapshot.LocalToRemoteNames[localPath]; ok {
		return s.repoFiles.RemotePath(relativePath)
	}
	return "", fmt.Errorf("could not find remote path for %s", localPath)
}
func (s *Sync) RunOnce(ctx context.Context) error {
// tradeoff: doing portable monitoring only due to macOS max descriptor manual ulimit setting requirement
// https://github.com/gorakhargosh/watchdog/blob/master/src/watchdog/observers/kqueue.py#L394-L418