mirror of https://github.com/databricks/cli.git
Add workspace import_dir command to the CLI
This commit is contained in:
parent
92cb52041d
commit
9f4bf1261b
|
@ -0,0 +1,84 @@
|
|||
package workspace
|
||||
|
||||
import (
|
||||
"github.com/databricks/cli/cmd/root"
|
||||
"github.com/databricks/cli/libs/cmdio"
|
||||
"github.com/databricks/cli/libs/sync"
|
||||
"github.com/databricks/databricks-sdk-go"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// TODO: add some comments here
|
||||
var importDirCmd = &cobra.Command{
|
||||
Use: "import_dir SOURCE_PATH TARGET_PATH",
|
||||
Short: `Recursively imports a directory from local to the Databricks workspace.`,
|
||||
Long: `
|
||||
Imports directory to the workspace.
|
||||
|
||||
This command respects your git ignore configuration. Notebooks with extensions
|
||||
.scala, .py, .sql, .r, .R, .ipynb are stripped of their extensions.
|
||||
`,
|
||||
|
||||
Annotations: map[string]string{},
|
||||
PreRunE: root.MustWorkspaceClient,
|
||||
Args: cobra.ExactArgs(2),
|
||||
RunE: func(cmd *cobra.Command, args []string) (err error) {
|
||||
ctx := cmd.Context()
|
||||
sourcePath := args[0]
|
||||
targetPath := args[1]
|
||||
|
||||
// Initialize syncer to do a full sync with the correct from source to target.
|
||||
// This will upload the local files
|
||||
opts := sync.SyncOptions{
|
||||
LocalPath: sourcePath,
|
||||
RemotePath: targetPath,
|
||||
Full: true,
|
||||
WorkspaceClient: databricks.Must(databricks.NewWorkspaceClient()),
|
||||
DisallowOverwrites: !importDirOverwrite,
|
||||
}
|
||||
s, err := sync.New(ctx, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize error wait group, and spawn the progress event emitter inside
|
||||
// the error wait group
|
||||
group, ctx := errgroup.WithContext(ctx)
|
||||
group.Go(
|
||||
func() error {
|
||||
return renderSyncEvents(ctx, s.Events(), s)
|
||||
})
|
||||
|
||||
// Start Uploading local files
|
||||
cmdio.Render(ctx, newImportStartedEvent(sourcePath, targetPath))
|
||||
err = s.RunOnce(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Upload completed, close the syncer
|
||||
s.Close()
|
||||
|
||||
// Wait for any inflight progress events to be emitted
|
||||
if err := group.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Render import completetion event
|
||||
cmdio.Render(ctx, newImportCompleteEvent(sourcePath, targetPath))
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// importDirOverwrite holds the value of the --overwrite flag of import_dir.
// Its negation is passed to the syncer as DisallowOverwrites.
var importDirOverwrite bool
|
||||
|
||||
func init() {
|
||||
importDirCmd.Annotations["template"] = cmdio.Heredoc(`
|
||||
{{if eq .Type "IMPORT_STARTED"}}Import started
|
||||
{{else if eq .Type "UPLOAD_COMPLETE"}}Uploaded {{.SourcePath}} -> {{.TargetPath}}
|
||||
{{else if eq .Type "IMPORT_COMPLETE"}}Import completed
|
||||
{{end}}
|
||||
`)
|
||||
importDirCmd.Flags().BoolVar(&importDirOverwrite, "overwrite", false, "Overwrite if file already exists in the workspace")
|
||||
Cmd.AddCommand(importDirCmd)
|
||||
}
|
|
@ -0,0 +1,67 @@
|
|||
package workspace
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/databricks/cli/libs/cmdio"
|
||||
"github.com/databricks/cli/libs/sync"
|
||||
)
|
||||
|
||||
// fileIOEvent is the payload rendered for each step of an import. The Type
// field selects the branch of the output template; the paths are omitted from
// JSON output when empty.
type fileIOEvent struct {
|
||||
// SourcePath is the local path the event refers to.
SourcePath string `json:"source_path,omitempty"`
|
||||
// TargetPath is the workspace path the event refers to.
TargetPath string `json:"target_path,omitempty"`
|
||||
// Type is one of "IMPORT_STARTED", "UPLOAD_COMPLETE" or "IMPORT_COMPLETE".
Type string `json:"type"`
|
||||
}
|
||||
|
||||
func newImportStartedEvent(sourcePath, targetPath string) fileIOEvent {
|
||||
return fileIOEvent{
|
||||
SourcePath: sourcePath,
|
||||
TargetPath: targetPath,
|
||||
Type: "IMPORT_STARTED",
|
||||
}
|
||||
}
|
||||
|
||||
func newImportCompleteEvent(sourcePath, targetPath string) fileIOEvent {
|
||||
return fileIOEvent{
|
||||
Type: "IMPORT_COMPLETE",
|
||||
}
|
||||
}
|
||||
|
||||
func newUploadCompleteEvent(sourcePath, targetPath string) fileIOEvent {
|
||||
return fileIOEvent{
|
||||
SourcePath: sourcePath,
|
||||
TargetPath: targetPath,
|
||||
Type: "UPLOAD_COMPLETE",
|
||||
}
|
||||
}
|
||||
|
||||
func renderSyncEvents(ctx context.Context, eventChannel <-chan sync.Event, syncer *sync.Sync) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case e, ok := <-eventChannel:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// We parse progress events from the sync to track when file uploads
|
||||
// are complete and emit the corresponding events
|
||||
if e.String() != "" && e.Type() == sync.EventTypeProgress {
|
||||
progressEvent := e.(*sync.EventSyncProgress)
|
||||
if progressEvent.Progress < 1 {
|
||||
return nil
|
||||
}
|
||||
// TODO: test this works with windows paths
|
||||
remotePath, err := syncer.RemotePath(progressEvent.Path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = cmdio.Render(ctx, newUploadCompleteEvent(progressEvent.Path, remotePath))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -49,6 +49,10 @@ func renderJson(w io.Writer, v any) error {
|
|||
return err
|
||||
}
|
||||
_, err = w.Write(pretty)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.Write([]byte("\n"))
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ const (
|
|||
|
||||
// Event is implemented by every event emitted during a sync.
type Event interface {
|
||||
// String returns the human-readable form of the event.
fmt.Stringer
|
||||
// Type returns the EventType constant identifying the kind of event, so
// consumers can dispatch on it before inspecting the concrete type.
Type() EventType
|
||||
}
|
||||
|
||||
type EventBase struct {
|
||||
|
@ -73,6 +74,10 @@ func (e *EventStart) String() string {
|
|||
return fmt.Sprintf("Action: %s", e.EventChanges.String())
|
||||
}
|
||||
|
||||
// Type returns EventTypeStart.
func (e *EventStart) Type() EventType {
|
||||
return EventTypeStart
|
||||
}
|
||||
|
||||
func newEventStart(seq int, put []string, delete []string) Event {
|
||||
return &EventStart{
|
||||
EventBase: newEventBase(seq, EventTypeStart),
|
||||
|
@ -106,6 +111,10 @@ func (e *EventSyncProgress) String() string {
|
|||
}
|
||||
}
|
||||
|
||||
// Type returns EventTypeProgress.
func (e *EventSyncProgress) Type() EventType {
|
||||
return EventTypeProgress
|
||||
}
|
||||
|
||||
func newEventProgress(seq int, action EventAction, path string, progress float32) Event {
|
||||
return &EventSyncProgress{
|
||||
EventBase: newEventBase(seq, EventTypeProgress),
|
||||
|
@ -133,6 +142,10 @@ func (e *EventSyncComplete) String() string {
|
|||
return "Complete"
|
||||
}
|
||||
|
||||
// Type returns EventTypeComplete.
func (e *EventSyncComplete) Type() EventType {
|
||||
return EventTypeComplete
|
||||
}
|
||||
|
||||
func newEventComplete(seq int, put []string, delete []string) Event {
|
||||
return &EventSyncComplete{
|
||||
EventBase: newEventBase(seq, EventTypeComplete),
|
||||
|
|
|
@ -17,23 +17,30 @@ import (
|
|||
"github.com/databricks/databricks-sdk-go/service/workspace"
|
||||
)
|
||||
|
||||
// RepoFileOptions configures how RepoFiles writes to the workspace.
type RepoFileOptions struct {
|
||||
// OverwriteIfExists is sent as the `overwrite` query parameter of the
// import-file request issued when writing a remote file.
OverwriteIfExists bool
|
||||
}
|
||||
|
||||
// RepoFiles wraps reading and writing into a remote repo with safeguards to prevent
|
||||
// accidental deletion of repos and more robust methods to overwrite workspace files
|
||||
type RepoFiles struct {
|
||||
// Embedded options; OverwriteIfExists is read when writing remote files.
// NOTE(review): this is a pointer supplied by the caller of Create — a nil
// value makes any access to the promoted fields panic.
*RepoFileOptions
|
||||
|
||||
// repoRoot is the remote root that relative paths are joined onto.
repoRoot string
|
||||
// localRoot is presumably the local directory that relative paths are
// resolved against when reading files — confirm against readLocal.
localRoot string
|
||||
// workspaceClient is the client stored by Create; presumably used for the
// remote API calls — confirm against writeRemote/deleteRemote.
workspaceClient *databricks.WorkspaceClient
|
||||
}
|
||||
|
||||
func Create(repoRoot, localRoot string, workspaceClient *databricks.WorkspaceClient) *RepoFiles {
|
||||
func Create(repoRoot, localRoot string, workspaceClient *databricks.WorkspaceClient, opts *RepoFileOptions) *RepoFiles {
|
||||
return &RepoFiles{
|
||||
repoRoot: repoRoot,
|
||||
localRoot: localRoot,
|
||||
workspaceClient: workspaceClient,
|
||||
RepoFileOptions: opts,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RepoFiles) remotePath(relativePath string) (string, error) {
|
||||
func (r *RepoFiles) RemotePath(relativePath string) (string, error) {
|
||||
fullPath := path.Join(r.repoRoot, relativePath)
|
||||
cleanFullPath := path.Clean(fullPath)
|
||||
if !strings.HasPrefix(cleanFullPath, r.repoRoot) {
|
||||
|
@ -58,12 +65,12 @@ func (r *RepoFiles) writeRemote(ctx context.Context, relativePath string, conten
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
remotePath, err := r.remotePath(relativePath)
|
||||
remotePath, err := r.RemotePath(relativePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
escapedPath := url.PathEscape(strings.TrimLeft(remotePath, "/"))
|
||||
apiPath := fmt.Sprintf("/api/2.0/workspace-files/import-file/%s?overwrite=true", escapedPath)
|
||||
apiPath := fmt.Sprintf("/api/2.0/workspace-files/import-file/%s?overwrite=%t", escapedPath, r.OverwriteIfExists)
|
||||
|
||||
err = apiClient.Do(ctx, http.MethodPost, apiPath, content, nil)
|
||||
|
||||
|
@ -113,7 +120,7 @@ func (r *RepoFiles) writeRemote(ctx context.Context, relativePath string, conten
|
|||
}
|
||||
|
||||
func (r *RepoFiles) deleteRemote(ctx context.Context, relativePath string) error {
|
||||
remotePath, err := r.remotePath(relativePath)
|
||||
remotePath, err := r.RemotePath(relativePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -10,68 +10,68 @@ import (
|
|||
|
||||
func TestRepoFilesRemotePath(t *testing.T) {
|
||||
repoRoot := "/Repos/doraemon/bar"
|
||||
repoFiles := Create(repoRoot, "/doraemon/foo/bar", nil)
|
||||
repoFiles := Create(repoRoot, "/doraemon/foo/bar", nil, nil)
|
||||
|
||||
remotePath, err := repoFiles.remotePath("a/b/c")
|
||||
remotePath, err := repoFiles.RemotePath("a/b/c")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, repoRoot+"/a/b/c", remotePath)
|
||||
|
||||
remotePath, err = repoFiles.remotePath("a/b/../d")
|
||||
remotePath, err = repoFiles.RemotePath("a/b/../d")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, repoRoot+"/a/d", remotePath)
|
||||
|
||||
remotePath, err = repoFiles.remotePath("a/../c")
|
||||
remotePath, err = repoFiles.RemotePath("a/../c")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, repoRoot+"/c", remotePath)
|
||||
|
||||
remotePath, err = repoFiles.remotePath("a/b/c/.")
|
||||
remotePath, err = repoFiles.RemotePath("a/b/c/.")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, repoRoot+"/a/b/c", remotePath)
|
||||
|
||||
remotePath, err = repoFiles.remotePath("a/b/c/d/./../../f/g")
|
||||
remotePath, err = repoFiles.RemotePath("a/b/c/d/./../../f/g")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, repoRoot+"/a/b/f/g", remotePath)
|
||||
|
||||
_, err = repoFiles.remotePath("..")
|
||||
_, err = repoFiles.RemotePath("..")
|
||||
assert.ErrorContains(t, err, `relative file path is not inside repo root: ..`)
|
||||
|
||||
_, err = repoFiles.remotePath("a/../..")
|
||||
_, err = repoFiles.RemotePath("a/../..")
|
||||
assert.ErrorContains(t, err, `relative file path is not inside repo root: a/../..`)
|
||||
|
||||
_, err = repoFiles.remotePath("./../.")
|
||||
_, err = repoFiles.RemotePath("./../.")
|
||||
assert.ErrorContains(t, err, `relative file path is not inside repo root: ./../.`)
|
||||
|
||||
_, err = repoFiles.remotePath("/./.././..")
|
||||
_, err = repoFiles.RemotePath("/./.././..")
|
||||
assert.ErrorContains(t, err, `relative file path is not inside repo root: /./.././..`)
|
||||
|
||||
_, err = repoFiles.remotePath("./../.")
|
||||
_, err = repoFiles.RemotePath("./../.")
|
||||
assert.ErrorContains(t, err, `relative file path is not inside repo root: ./../.`)
|
||||
|
||||
_, err = repoFiles.remotePath("./..")
|
||||
_, err = repoFiles.RemotePath("./..")
|
||||
assert.ErrorContains(t, err, `relative file path is not inside repo root: ./..`)
|
||||
|
||||
_, err = repoFiles.remotePath("./../../..")
|
||||
_, err = repoFiles.RemotePath("./../../..")
|
||||
assert.ErrorContains(t, err, `relative file path is not inside repo root: ./../../..`)
|
||||
|
||||
_, err = repoFiles.remotePath("./../a/./b../../..")
|
||||
_, err = repoFiles.RemotePath("./../a/./b../../..")
|
||||
assert.ErrorContains(t, err, `relative file path is not inside repo root: ./../a/./b../../..`)
|
||||
|
||||
_, err = repoFiles.remotePath("../..")
|
||||
_, err = repoFiles.RemotePath("../..")
|
||||
assert.ErrorContains(t, err, `relative file path is not inside repo root: ../..`)
|
||||
|
||||
_, err = repoFiles.remotePath(".//a/..//./b/..")
|
||||
_, err = repoFiles.RemotePath(".//a/..//./b/..")
|
||||
assert.ErrorContains(t, err, `file path relative to repo root cannot be empty`)
|
||||
|
||||
_, err = repoFiles.remotePath("a/b/../..")
|
||||
_, err = repoFiles.RemotePath("a/b/../..")
|
||||
assert.ErrorContains(t, err, "file path relative to repo root cannot be empty")
|
||||
|
||||
_, err = repoFiles.remotePath("")
|
||||
_, err = repoFiles.RemotePath("")
|
||||
assert.ErrorContains(t, err, "file path relative to repo root cannot be empty")
|
||||
|
||||
_, err = repoFiles.remotePath(".")
|
||||
_, err = repoFiles.RemotePath(".")
|
||||
assert.ErrorContains(t, err, "file path relative to repo root cannot be empty")
|
||||
|
||||
_, err = repoFiles.remotePath("/")
|
||||
_, err = repoFiles.RemotePath("/")
|
||||
assert.ErrorContains(t, err, "file path relative to repo root cannot be empty")
|
||||
}
|
||||
|
||||
|
@ -81,7 +81,7 @@ func TestRepoReadLocal(t *testing.T) {
|
|||
err := os.WriteFile(helloPath, []byte("my name is doraemon :P"), os.ModePerm)
|
||||
assert.NoError(t, err)
|
||||
|
||||
repoFiles := Create("/Repos/doraemon/bar", tempDir, nil)
|
||||
repoFiles := Create("/Repos/doraemon/bar", tempDir, nil, nil)
|
||||
bytes, err := repoFiles.readLocal("./a/../hello.txt")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "my name is doraemon :P", string(bytes))
|
||||
|
|
|
@ -24,6 +24,10 @@ type SyncOptions struct {
|
|||
WorkspaceClient *databricks.WorkspaceClient
|
||||
|
||||
Host string
|
||||
|
||||
// If set, sync will not be able to overwrite any existing paths on the
|
||||
// workspace file system.
|
||||
DisallowOverwrites bool
|
||||
}
|
||||
|
||||
type Sync struct {
|
||||
|
@ -76,8 +80,9 @@ func New(ctx context.Context, opts SyncOptions) (*Sync, error) {
|
|||
return nil, fmt.Errorf("unable to load sync snapshot: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
repoFiles := repofiles.Create(opts.RemotePath, opts.LocalPath, opts.WorkspaceClient)
|
||||
repoFiles := repofiles.Create(opts.RemotePath, opts.LocalPath, opts.WorkspaceClient, &repofiles.RepoFileOptions{
|
||||
OverwriteIfExists: !opts.DisallowOverwrites,
|
||||
})
|
||||
|
||||
return &Sync{
|
||||
SyncOptions: &opts,
|
||||
|
@ -125,6 +130,14 @@ func (s *Sync) notifyComplete(ctx context.Context, d diff) {
|
|||
s.seq++
|
||||
}
|
||||
|
||||
func (s *Sync) RemotePath(localPath string) (string, error) {
|
||||
relativePath, ok := s.snapshot.LocalToRemoteNames[localPath]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("could not find remote path for %s", localPath)
|
||||
}
|
||||
return s.repoFiles.RemotePath(relativePath)
|
||||
}
|
||||
|
||||
func (s *Sync) RunOnce(ctx context.Context) error {
|
||||
// tradeoff: doing portable monitoring only due to macOS max descriptor manual ulimit setting requirement
|
||||
// https://github.com/gorakhargosh/watchdog/blob/master/src/watchdog/observers/kqueue.py#L394-L418
|
||||
|
|
Loading…
Reference in New Issue