diff --git a/cmd/sync/snapshot.go b/cmd/sync/snapshot.go index 14712a4e..65a6306b 100644 --- a/cmd/sync/snapshot.go +++ b/cmd/sync/snapshot.go @@ -1,12 +1,14 @@ package sync import ( + "bufio" "context" "encoding/json" "fmt" "io" "os" "path/filepath" + "regexp" "strings" "time" @@ -38,6 +40,13 @@ type Snapshot struct { // key: relative file path from project root // value: last time the remote instance of this file was updated LastUpdatedTimes map[string]time.Time `json:"last_modified_times"` + // This map maps local file names to their remote names + // eg. notebook named "foo.py" locally would be stored as "foo", thus this + // map will contain an entry "foo.py" -> "foo" + LocalToRemoteNames map[string]string `json:"local_to_remote_names"` + // Inverse of LocalToRemoteNames. Together they form a bijective mapping (ie + // there is a 1:1 unique mapping between local and remote name) + RemoteToLocalNames map[string]string `json:"remote_to_local_names"` } type diff struct { @@ -87,9 +96,11 @@ func newSnapshot(ctx context.Context, remotePath string) (*Snapshot, error) { } return &Snapshot{ - Host: host, - RemotePath: remotePath, - LastUpdatedTimes: make(map[string]time.Time), + Host: host, + RemotePath: remotePath, + LastUpdatedTimes: make(map[string]time.Time), + LocalToRemoteNames: make(map[string]string), + RemoteToLocalNames: make(map[string]string), }, nil } @@ -161,9 +172,37 @@ func (d diff) String() string { return strings.Join(changes, ", ") } -func (s Snapshot) diff(all []git.File) (change diff) { +func getNotebookDetails(path string) (isNotebook bool, typeOfNotebook string, err error) { + isNotebook = false + typeOfNotebook = "" + + isPythonFile, err := regexp.Match(`\.py$`, []byte(path)) + if err != nil { + return + } + if isPythonFile { + f, err := os.Open(path) + if err != nil { + return false, "", err + } + defer f.Close() + scanner := bufio.NewScanner(f) + ok := scanner.Scan() + if !ok { + return false, "", scanner.Err() + } + // A python file is a 
notebook if it starts with the following magic string + isNotebook = strings.Contains(scanner.Text(), "# Databricks notebook source") + return isNotebook, "PYTHON", nil + } + return false, "", nil +} + +func (s Snapshot) diff(all []git.File) (change diff, err error) { currentFilenames := map[string]bool{} lastModifiedTimes := s.LastUpdatedTimes + remoteToLocalNames := s.RemoteToLocalNames + localToRemoteNames := s.LocalToRemoteNames for _, f := range all { // create set of current files to figure out if removals are needed currentFilenames[f.Relative] = true @@ -174,22 +213,52 @@ func (s Snapshot) diff(all []git.File) (change diff) { if !seen || modified.After(lastSeenModified) { change.put = append(change.put, f.Relative) lastModifiedTimes[f.Relative] = modified + + // Keep track of remote names of local files so we can delete them later + isNotebook, typeOfNotebook, err := getNotebookDetails(f.Absolute) + if err != nil { + return change, err + } + remoteName := f.Relative + if isNotebook && typeOfNotebook == "PYTHON" { + remoteName = strings.TrimSuffix(f.Relative, `.py`) + } + + // If the remote handle of a file changes, we want to delete the old + // remote version of that file to avoid duplicates. + // This can happen if a python notebook is converted to a python + // script or vice versa + oldRemoteName, ok := localToRemoteNames[f.Relative] + if ok && oldRemoteName != remoteName { + change.delete = append(change.delete, oldRemoteName) + delete(remoteToLocalNames, oldRemoteName) + } + // We cannot allow two local files in the project to point to the same + // remote path + oldLocalName, ok := remoteToLocalNames[remoteName] + if ok && oldLocalName != f.Relative { + return change, fmt.Errorf("both %s and %s point to the same remote file location %s. 
Please remove one of them from your local project", oldLocalName, f.Relative, remoteName) + } + localToRemoteNames[f.Relative] = remoteName + remoteToLocalNames[remoteName] = f.Relative } } - // figure out files in the snapshot, but not on local filesystem - for relative := range lastModifiedTimes { - _, exists := currentFilenames[relative] + // figure out files in the snapshot.lastModifiedTimes, but not on local + // filesystem. These will be deleted + for localName := range lastModifiedTimes { + _, exists := currentFilenames[localName] if exists { continue } // add them to a delete batch - change.delete = append(change.delete, relative) - // remove the file from snapshot - delete(lastModifiedTimes, relative) + change.delete = append(change.delete, localToRemoteNames[localName]) } // and remove them from the snapshot - for _, v := range change.delete { - delete(lastModifiedTimes, v) + for _, remoteName := range change.delete { + localName := remoteToLocalNames[remoteName] + delete(lastModifiedTimes, localName) + delete(remoteToLocalNames, remoteName) + delete(localToRemoteNames, localName) } return } diff --git a/cmd/sync/snapshot_test.go b/cmd/sync/snapshot_test.go index ef945008..49665c4a 100644 --- a/cmd/sync/snapshot_test.go +++ b/cmd/sync/snapshot_test.go @@ -10,57 +10,220 @@ import ( "github.com/stretchr/testify/assert" ) +type testFile struct { + mtime time.Time + fd *os.File + path string +} + +func createFile(t *testing.T, path string) *testFile { + f, err := os.Create(path) + assert.NoError(t, err) + + fileInfo, err := os.Stat(path) + assert.NoError(t, err) + + return &testFile{ + path: path, + fd: f, + mtime: fileInfo.ModTime(), + } +} + +func (f *testFile) close(t *testing.T) { + err := f.fd.Close() + assert.NoError(t, err) +} + +func (f *testFile) overwrite(t *testing.T, s string) { + err := os.Truncate(f.path, 0) + assert.NoError(t, err) + + _, err = f.fd.Seek(0, 0) + assert.NoError(t, err) + + _, err = f.fd.WriteString(s) + assert.NoError(t, err) + 
+ // We manually update mtime after write because github actions file + // system does not :') + err = os.Chtimes(f.path, f.mtime.Add(time.Minute), f.mtime.Add(time.Minute)) + assert.NoError(t, err) + f.mtime = f.mtime.Add(time.Minute) +} + +func (f *testFile) remove(t *testing.T) { + err := os.Remove(f.path) + assert.NoError(t, err) +} + +func assertKeysOfMap(t *testing.T, m map[string]time.Time, expectedKeys []string) { + keys := make([]string, len(m)) + i := 0 + for k := range m { + keys[i] = k + i++ + } + assert.ElementsMatch(t, expectedKeys, keys) +} + func TestDiff(t *testing.T) { // Create temp project dir projectDir := t.TempDir() - - f1, err := os.Create(filepath.Join(projectDir, "hello.txt")) - assert.NoError(t, err) - defer f1.Close() - worldFilePath := filepath.Join(projectDir, "world.txt") - f2, err := os.Create(worldFilePath) - assert.NoError(t, err) - defer f2.Close() - fileSet := git.NewFileSet(projectDir) + state := Snapshot{ + LastUpdatedTimes: make(map[string]time.Time), + LocalToRemoteNames: make(map[string]string), + RemoteToLocalNames: make(map[string]string), + } + + f1 := createFile(t, filepath.Join(projectDir, "hello.txt")) + defer f1.close(t) + worldFilePath := filepath.Join(projectDir, "world.txt") + f2 := createFile(t, worldFilePath) + defer f2.close(t) + + // New files are put files, err := fileSet.All() assert.NoError(t, err) - state := Snapshot{ - LastUpdatedTimes: make(map[string]time.Time), - } - change := state.diff(files) - - // New files are added to put + change, err := state.diff(files) + assert.NoError(t, err) assert.Len(t, change.delete, 0) assert.Len(t, change.put, 2) assert.Contains(t, change.put, "hello.txt") assert.Contains(t, change.put, "world.txt") + assertKeysOfMap(t, state.LastUpdatedTimes, []string{"hello.txt", "world.txt"}) + assert.Equal(t, map[string]string{"hello.txt": "hello.txt", "world.txt": "world.txt"}, state.LocalToRemoteNames) + assert.Equal(t, map[string]string{"hello.txt": "hello.txt", "world.txt": 
"world.txt"}, state.RemoteToLocalNames) - // Edited files are added to put. - // File system in the github actions env does not update - // mtime on writes to a file. So we are manually editting it - // instead of writing to world.txt - worldFileInfo, err := os.Stat(worldFilePath) - assert.NoError(t, err) - os.Chtimes(worldFilePath, - worldFileInfo.ModTime().Add(time.Nanosecond), - worldFileInfo.ModTime().Add(time.Nanosecond)) - + // world.txt is editted + f2.overwrite(t, "bunnies are cute.") assert.NoError(t, err) files, err = fileSet.All() assert.NoError(t, err) - change = state.diff(files) + change, err = state.diff(files) + assert.NoError(t, err) + assert.Len(t, change.delete, 0) assert.Len(t, change.put, 1) assert.Contains(t, change.put, "world.txt") + assertKeysOfMap(t, state.LastUpdatedTimes, []string{"hello.txt", "world.txt"}) + assert.Equal(t, map[string]string{"hello.txt": "hello.txt", "world.txt": "world.txt"}, state.LocalToRemoteNames) + assert.Equal(t, map[string]string{"hello.txt": "hello.txt", "world.txt": "world.txt"}, state.RemoteToLocalNames) - // Removed files are added to delete - err = os.Remove(filepath.Join(projectDir, "hello.txt")) + // hello.txt is deleted + f1.remove(t) assert.NoError(t, err) files, err = fileSet.All() assert.NoError(t, err) - change = state.diff(files) + change, err = state.diff(files) + assert.NoError(t, err) assert.Len(t, change.delete, 1) assert.Len(t, change.put, 0) assert.Contains(t, change.delete, "hello.txt") + assertKeysOfMap(t, state.LastUpdatedTimes, []string{"world.txt"}) + assert.Equal(t, map[string]string{"world.txt": "world.txt"}, state.LocalToRemoteNames) + assert.Equal(t, map[string]string{"world.txt": "world.txt"}, state.RemoteToLocalNames) +} + +func TestPythonNotebookDiff(t *testing.T) { + // Create temp project dir + projectDir := t.TempDir() + fileSet := git.NewFileSet(projectDir) + state := Snapshot{ + LastUpdatedTimes: make(map[string]time.Time), + LocalToRemoteNames: make(map[string]string), + 
RemoteToLocalNames: make(map[string]string), + } + + foo := createFile(t, filepath.Join(projectDir, "foo.py")) + defer foo.close(t) + + // Case 1: notebook foo.py is uploaded + files, err := fileSet.All() + assert.NoError(t, err) + foo.overwrite(t, "# Databricks notebook source\nprint(\"abc\")") + change, err := state.diff(files) + assert.NoError(t, err) + assert.Len(t, change.delete, 0) + assert.Len(t, change.put, 1) + assert.Contains(t, change.put, "foo.py") + assertKeysOfMap(t, state.LastUpdatedTimes, []string{"foo.py"}) + assert.Equal(t, map[string]string{"foo.py": "foo"}, state.LocalToRemoteNames) + assert.Equal(t, map[string]string{"foo": "foo.py"}, state.RemoteToLocalNames) + + // Case 2: notebook foo.py is converted to python script by removing + // magic keyword + foo.overwrite(t, "print(\"abc\")") + files, err = fileSet.All() + assert.NoError(t, err) + change, err = state.diff(files) + assert.NoError(t, err) + assert.Len(t, change.delete, 1) + assert.Len(t, change.put, 1) + assert.Contains(t, change.put, "foo.py") + assert.Contains(t, change.delete, "foo") + assertKeysOfMap(t, state.LastUpdatedTimes, []string{"foo.py"}) + assert.Equal(t, map[string]string{"foo.py": "foo.py"}, state.LocalToRemoteNames) + assert.Equal(t, map[string]string{"foo.py": "foo.py"}, state.RemoteToLocalNames) + + // Case 3: Python script foo.py is converted to a databricks notebook + foo.overwrite(t, "# Databricks notebook source\nprint(\"def\")") + files, err = fileSet.All() + assert.NoError(t, err) + change, err = state.diff(files) + assert.NoError(t, err) + assert.Len(t, change.delete, 1) + assert.Len(t, change.put, 1) + assert.Contains(t, change.put, "foo.py") + assert.Contains(t, change.delete, "foo.py") + assertKeysOfMap(t, state.LastUpdatedTimes, []string{"foo.py"}) + assert.Equal(t, map[string]string{"foo.py": "foo"}, state.LocalToRemoteNames) + assert.Equal(t, map[string]string{"foo": "foo.py"}, state.RemoteToLocalNames) + + // Case 4: Python notebook foo.py is deleted, 
and its remote name is used in change.delete + foo.remove(t) + assert.NoError(t, err) + files, err = fileSet.All() + assert.NoError(t, err) + change, err = state.diff(files) + assert.NoError(t, err) + assert.Len(t, change.delete, 1) + assert.Len(t, change.put, 0) + assert.Contains(t, change.delete, "foo") + assert.Len(t, state.LastUpdatedTimes, 0) + assert.Equal(t, map[string]string{}, state.LocalToRemoteNames) + assert.Equal(t, map[string]string{}, state.RemoteToLocalNames) +} + +func TestErrorWhenIdenticalRemoteName(t *testing.T) { + // Create temp project dir + projectDir := t.TempDir() + fileSet := git.NewFileSet(projectDir) + state := Snapshot{ + LastUpdatedTimes: make(map[string]time.Time), + LocalToRemoteNames: make(map[string]string), + RemoteToLocalNames: make(map[string]string), + } + + // upload should work since they point to different destinations + pythonFoo := createFile(t, filepath.Join(projectDir, "foo.py")) + defer pythonFoo.close(t) + vanillaFoo := createFile(t, filepath.Join(projectDir, "foo")) + defer vanillaFoo.close(t) + files, err := fileSet.All() + assert.NoError(t, err) + change, err := state.diff(files) + assert.NoError(t, err) + assert.Len(t, change.delete, 0) + assert.Len(t, change.put, 2) + assert.Contains(t, change.put, "foo.py") + assert.Contains(t, change.put, "foo") + + // errors out because they point to the same destination + pythonFoo.overwrite(t, "# Databricks notebook source\nprint(\"def\")") + files, err = fileSet.All() + assert.NoError(t, err) + change, err = state.diff(files) + assert.ErrorContains(t, err, "both foo and foo.py point to the same remote file location foo. 
Please remove one of them from your local project") } diff --git a/cmd/sync/watchdog.go b/cmd/sync/watchdog.go index a05a9c7b..19e55a32 100644 --- a/cmd/sync/watchdog.go +++ b/cmd/sync/watchdog.go @@ -31,6 +31,18 @@ type watchdog struct { // rate limits const MaxRequestsInFlight = 20 +// path: The local path of the file in the local file system +// +// The API calls for a python script foo.py would be +// `PUT foo.py` +// `DELETE foo.py` +// +// The API calls for a python notebook foo.py would be +// `PUT foo.py` +// `DELETE foo` +// +// The workspace file system backend strips .py from the file name if the python +// file is a notebook func putFile(ctx context.Context, path string, content io.Reader) error { wsc := project.Get(ctx).WorkspacesClient() // workspace mkdirs is idempotent @@ -48,6 +60,7 @@ func putFile(ctx context.Context, path string, content io.Reader) error { return apiClient.Post(ctx, apiPath, content, nil) } +// path: The remote path of the file in the workspace func deleteFile(ctx context.Context, path string, wsc *workspaces.WorkspacesClient) error { err := wsc.Workspace.Delete(ctx, workspace.Delete{ @@ -75,29 +88,30 @@ func getRemoteSyncCallback(ctx context.Context, root, remoteDir string, wsc *wor // Allow MaxRequestLimit maxiumum concurrent api calls g.SetLimit(MaxRequestsInFlight) - for _, fileName := range d.delete { - // Copy of fileName created to make this safe for concurrent use. - // directly using fileName can cause race conditions since the loop - // might iterate over to the next fileName before the go routine function + for _, remoteName := range d.delete { + // Copy of remoteName created to make this safe for concurrent use. 
+ // directly using remoteName can cause race conditions since the loop + // might iterate over to the next remoteName before the go routine function // is evaluated - localFileName := fileName + remoteNameCopy := remoteName g.Go(func() error { - err := deleteFile(ctx, path.Join(remoteDir, localFileName), wsc) + err := deleteFile(ctx, path.Join(remoteDir, remoteNameCopy), wsc) if err != nil { return err } - log.Printf("[INFO] Deleted %s", localFileName) + log.Printf("[INFO] Deleted %s", remoteNameCopy) return nil }) } - for _, fileName := range d.put { - localFileName := fileName + for _, localName := range d.put { + // Copy of localName created to make this safe for concurrent use. + localNameCopy := localName g.Go(func() error { - f, err := os.Open(filepath.Join(root, localFileName)) + f, err := os.Open(filepath.Join(root, localNameCopy)) if err != nil { return err } - err = putFile(ctx, path.Join(remoteDir, localFileName), f) + err = putFile(ctx, path.Join(remoteDir, localNameCopy), f) if err != nil { return fmt.Errorf("failed to upload file: %s", err) } @@ -105,7 +119,7 @@ func getRemoteSyncCallback(ctx context.Context, root, remoteDir string, wsc *wor if err != nil { return err } - log.Printf("[INFO] Uploaded %s", localFileName) + log.Printf("[INFO] Uploaded %s", localNameCopy) return nil }) } @@ -162,7 +176,11 @@ func (w *watchdog) main(ctx context.Context, applyDiff func(diff) error, remoteP w.failure = err return } - change := snapshot.diff(all) + change, err := snapshot.diff(all) + if err != nil { + w.failure = err + return + } if change.IsEmpty() { onlyOnceInitLog.Do(func() { log.Printf("[INFO] Initial Sync Complete")