diff --git a/cmd/sync/snapshot.go b/cmd/sync/snapshot.go index 3dd81bc9..da7646e0 100644 --- a/cmd/sync/snapshot.go +++ b/cmd/sync/snapshot.go @@ -1,7 +1,11 @@ package sync import ( + "encoding/json" "fmt" + "io" + "os" + "path/filepath" "strings" "time" @@ -15,6 +19,61 @@ type diff struct { delete []string } +const SyncSnapshotFile = "repo_snapshot.json" +const BricksDir = ".bricks" + +func (s *snapshot) storeSnapshot(root string) error { + // create snapshot file + configDir := filepath.Join(root, BricksDir) + if _, err := os.Stat(configDir); os.IsNotExist(err) { + err = os.Mkdir(configDir, os.ModeDir|os.ModePerm) + if err != nil { + return fmt.Errorf("failed to create config directory: %s", err) + } + } + persistedSnapshotPath := filepath.Join(configDir, SyncSnapshotFile) + f, err := os.OpenFile(persistedSnapshotPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0755) + if err != nil { + return fmt.Errorf("failed to create/open persisted sync snapshot file: %s", err) + } + defer f.Close() + + // persist snapshot to disk + bytes, err := json.MarshalIndent(s, "", " ") + if err != nil { + return fmt.Errorf("failed to json marshal in-memory snapshot: %s", err) + } + _, err = f.Write(bytes) + if err != nil { + return fmt.Errorf("failed to write sync snapshot to disk: %s", err) + } + return nil +} + +func (s *snapshot) loadSnapshot(root string) error { + persistedSnapshotPath := filepath.Join(root, BricksDir, SyncSnapshotFile) + if _, err := os.Stat(persistedSnapshotPath); os.IsNotExist(err) { + return nil + } + + f, err := os.Open(persistedSnapshotPath) + if err != nil { + return fmt.Errorf("failed to open persisted sync snapshot file: %s", err) + } + defer f.Close() + + bytes, err := io.ReadAll(f) + if err != nil { + // clean up these error messages a bit + return fmt.Errorf("failed to read sync snapshot from disk: %s", err) + } + err = json.Unmarshal(bytes, s) + if err != nil { + return fmt.Errorf("failed to json unmarshal persisted snapshot: %s", err) + } + return nil +} + 
func (d diff) IsEmpty() bool { return len(d.put) == 0 && len(d.delete) == 0 } @@ -41,11 +100,11 @@ func (s snapshot) diff(all []git.File) (change diff) { // get current modified timestamp modified := f.Modified() lastSeenModified, seen := s[f.Relative] - if !(!seen || modified.After(lastSeenModified)) { - continue + + if !seen || modified.After(lastSeenModified) { + change.put = append(change.put, f.Relative) + s[f.Relative] = modified } - change.put = append(change.put, f.Relative) - s[f.Relative] = modified } // figure out files in the snapshot, but not on local filesystem for relative := range s { @@ -53,8 +112,10 @@ func (s snapshot) diff(all []git.File) (change diff) { if exists { continue } - // and add them to a delete batch + // add them to a delete batch change.delete = append(change.delete, relative) + // remove the file from snapshot + delete(s, relative) } // and remove them from the snapshot for _, v := range change.delete { diff --git a/cmd/sync/snapshot_test.go b/cmd/sync/snapshot_test.go new file mode 100644 index 00000000..c22b2994 --- /dev/null +++ b/cmd/sync/snapshot_test.go @@ -0,0 +1,54 @@ +package sync + +import ( + "os" + "path/filepath" + "testing" + + "github.com/databricks/bricks/git" + "github.com/stretchr/testify/assert" +) + +func TestDiff(t *testing.T) { + // Create temp project dir + projectDir := t.TempDir() + + f1, err := os.Create(filepath.Join(projectDir, "hello.txt")) + assert.NoError(t, err) + defer f1.Close() + f2, err := os.Create(filepath.Join(projectDir, "world.txt")) + assert.NoError(t, err) + defer f2.Close() + + fileSet := git.NewFileSet(projectDir) + files, err := fileSet.All() + assert.NoError(t, err) + state := snapshot{} + change := state.diff(files) + + // New files are added to put + assert.Len(t, change.delete, 0) + assert.Len(t, change.put, 2) + assert.Contains(t, change.put, "hello.txt") + assert.Contains(t, change.put, "world.txt") + + // Edited files are added to put + _, err = f2.WriteString("I like clis") + 
assert.NoError(t, err) + files, err = fileSet.All() + assert.NoError(t, err) + change = state.diff(files) + assert.Len(t, change.delete, 0) + assert.Len(t, change.put, 1) + assert.Contains(t, change.put, "world.txt") + + // Removed files are added to delete + err = os.Remove(filepath.Join(projectDir, "hello.txt")) + assert.NoError(t, err) + files, err = fileSet.All() + assert.NoError(t, err) + change = state.diff(files) + assert.Len(t, change.delete, 1) + assert.Len(t, change.put, 0) + assert.Contains(t, change.delete, "hello.txt") +} diff --git a/cmd/sync/sync.go b/cmd/sync/sync.go index 5c7a2c0a..8414ebf5 100644 --- a/cmd/sync/sync.go +++ b/cmd/sync/sync.go @@ -3,14 +3,11 @@ package sync import ( "fmt" "log" - "path" "time" "github.com/databricks/bricks/cmd/root" "github.com/databricks/bricks/git" "github.com/databricks/bricks/project" - "github.com/databricks/bricks/utilities" - "github.com/databricks/databricks-sdk-go/service/workspace" "github.com/spf13/cobra" ) @@ -19,57 +16,47 @@ var syncCmd = &cobra.Command{ Use: "sync", Short: "run syncs for the project", RunE: func(cmd *cobra.Command, args []string) error { - origin, err := git.HttpsOrigin() - if err != nil { - return err - } - log.Printf("[INFO] %s", origin) ctx := cmd.Context() - wsc := project.Current.WorkspacesClient() - checkouts, err := utilities.GetAllRepos(ctx, wsc, "/") - if err != nil { - return err - } - for _, v := range checkouts { - log.Printf("[INFO] %s", v.Url) - } - me := project.Current.Me() - repositoryName, err := git.RepositoryName() - if err != nil { - return err - } - base := fmt.Sprintf("/Repos/%s/%s", me.UserName, repositoryName) - return watchForChanges(ctx, git.MustGetFileSet(), *interval, func(d diff) error { - for _, v := range d.delete { - err := wsc.Workspace.Delete(ctx, - workspace.DeleteRequest{ - Path: path.Join(base, v), - Recursive: true, - }, - ) - if err != nil { - return err - } + + if *remotePath == "" { + me, err := project.Current.Me() + if err != nil { + return 
err } - return nil - }) + repositoryName, err := git.RepositoryName() + if err != nil { + return err + } + *remotePath = fmt.Sprintf("/Repos/%s/%s", me.UserName, repositoryName) + } + + log.Printf("[INFO] Remote file sync location: %v", *remotePath) + repoExists, err := git.RepoExists(*remotePath, ctx, wsc) + if err != nil { + return err + } + if !repoExists { + return fmt.Errorf("repo not found, please ensure %s exists", *remotePath) + } + + fileSet, err := git.GetFileSet() + if err != nil { + return err + } + syncCallback := getRemoteSyncCallback(ctx, *remotePath, wsc) + err = spawnSyncRoutine(ctx, fileSet, *interval, syncCallback) + return err }, } -// func ImportFile(ctx context.Context, path string, content io.Reader) error { -// client := project.Current.Client() -// apiPath := fmt.Sprintf( -// "/workspace-files/import-file/%s?overwrite=true", -// strings.TrimLeft(path, "/")) -// // TODO: change upstream client to support io.Reader as body -// return client.Post(ctx, apiPath, content, nil) -// } - // project files polling interval var interval *time.Duration +var remotePath *string + func init() { root.RootCmd.AddCommand(syncCmd) interval = syncCmd.Flags().Duration("interval", 1*time.Second, "project files polling interval") + remotePath = syncCmd.Flags().String("remote-path", "", "remote path to store repo in. 
eg: /Repos/me@example.com/test-repo") } diff --git a/cmd/sync/watchdog.go b/cmd/sync/watchdog.go index e186d9c1..db57ec9b 100644 --- a/cmd/sync/watchdog.go +++ b/cmd/sync/watchdog.go @@ -2,11 +2,21 @@ package sync import ( "context" + "fmt" + "io" "log" + "os" + "path" + "path/filepath" + "strings" "sync" "time" "github.com/databricks/bricks/git" + "github.com/databricks/bricks/project" + "github.com/databricks/databricks-sdk-go/databricks/client" + "github.com/databricks/databricks-sdk-go/service/workspace" + "github.com/databricks/databricks-sdk-go/workspaces" ) type watchdog struct { @@ -16,24 +26,86 @@ type watchdog struct { failure error // data race? make channel? } -func watchForChanges(ctx context.Context, files git.FileSet, - interval time.Duration, cb func(diff) error) error { +func putFile(ctx context.Context, path string, content io.Reader) error { + wsc := project.Current.WorkspacesClient() + // workspace mkdirs is idempotent + err := wsc.Workspace.MkdirsByPath(ctx, filepath.Dir(path)) + if err != nil { + return fmt.Errorf("could not mkdir to put file: %s", err) + } + apiClient := client.New(wsc.Config) + apiPath := fmt.Sprintf( + "/api/2.0/workspace-files/import-file/%s?overwrite=true", + strings.TrimLeft(path, "/")) + return apiClient.Post(ctx, apiPath, content, nil) +} + +func getRemoteSyncCallback(ctx context.Context, remoteDir string, wsc *workspaces.WorkspacesClient) func(localDiff diff) error { + return func(d diff) error { + for _, filePath := range d.delete { + err := wsc.Workspace.Delete(ctx, + workspace.DeleteRequest{ + Path: path.Join(remoteDir, filePath), + Recursive: true, + }, + ) + if err != nil { + return err + } + log.Printf("[INFO] Deleted %s", filePath) + } + for _, filePath := range d.put { + f, err := os.Open(filePath) + if err != nil { + return err + } + err = putFile(ctx, path.Join(remoteDir, filePath), f) + if err != nil { + return fmt.Errorf("failed to upload file: %s", err) // TODO: fmt.Errorf + } + err = f.Close() + if err 
!= nil { + return err // TODO: fmt.Errorf + } + log.Printf("[INFO] Uploaded %s", filePath) + } + return nil + } +} + +func spawnSyncRoutine(ctx context.Context, + files git.FileSet, + interval time.Duration, + applyDiff func(diff) error) error { w := &watchdog{ files: files, ticker: time.NewTicker(interval), } w.wg.Add(1) - go w.main(ctx, cb) + go w.main(ctx, applyDiff) w.wg.Wait() return w.failure } // tradeoff: doing portable monitoring only due to macOS max descriptor manual ulimit setting requirement // https://github.com/gorakhargosh/watchdog/blob/master/src/watchdog/observers/kqueue.py#L394-L418 -func (w *watchdog) main(ctx context.Context, cb func(diff) error) { +func (w *watchdog) main(ctx context.Context, applyDiff func(diff) error) { defer w.wg.Done() // load from json or sync it every time there's an action state := snapshot{} + root, err := git.Root() + if err != nil { + log.Printf("[ERROR] cannot find project root: %s", err) + w.failure = err + return + } + err = state.loadSnapshot(root) + if err != nil { + log.Printf("[ERROR] cannot load snapshot: %s", err) + w.failure = err + return + } + for { select { case <-ctx.Done(): @@ -50,11 +122,17 @@ func (w *watchdog) main(ctx context.Context, cb func(diff) error) { continue } log.Printf("[INFO] Action: %v", change) - err = cb(change) + err = applyDiff(change) if err != nil { w.failure = err return } + err = state.storeSnapshot(root) + if err != nil { + log.Printf("[ERROR] cannot store snapshot: %s", err) + w.failure = err + return + } } } } diff --git a/git/fileset.go b/git/fileset.go index 4259c068..c2c83ee9 100644 --- a/git/fileset.go +++ b/git/fileset.go @@ -25,25 +25,27 @@ func (f File) Modified() (ts time.Time) { return info.ModTime() } -// FileSet facilitates fast recursive file listing with -// respect to patterns defined in `.gitignore` file +// FileSet facilitates fast recursive tracked file listing +// with respect to patterns defined in `.gitignore` file +// +// root: Root of the git repository 
+// ignore: List of patterns defined in `.gitignore`. +// We do not sync files that match this pattern type FileSet struct { root string ignore *ignore.GitIgnore } -// MustGetFileSet retrieves FileSet from Git repository checkout root +// GetFileSet retrieves FileSet from Git repository checkout root // or panics if no root is detected. -func MustGetFileSet() FileSet { +func GetFileSet() (FileSet, error) { root, err := Root() - if err != nil { - panic(err) - } - return New(root) + return NewFileSet(root), err } -func New(root string) FileSet { - lines := []string{".git"} +// Returns FileSet for the repository located at `root` +func NewFileSet(root string) FileSet { + lines := []string{".git", ".bricks"} rawIgnore, err := os.ReadFile(fmt.Sprintf("%s/.gitignore", root)) if err == nil { // add entries from .gitignore if the file exists (did read correctly) @@ -59,11 +61,15 @@ func New(root string) FileSet { } } +// Return all tracked files for Repo func (w *FileSet) All() ([]File, error) { - return w.RecursiveChildren(w.root) + return w.RecursiveListFiles(w.root) } -func (w *FileSet) RecursiveChildren(dir string) (found []File, err error) { +// Recursively traverses dir in a depth first manner and returns a list of all files +// that are being tracked in the FileSet (ie not being ignored for matching one of the +// patterns in w.ignore) +func (w *FileSet) RecursiveListFiles(dir string) (fileList []File, err error) { queue, err := readDir(dir, w.root) if err != nil { return nil, err @@ -75,7 +81,7 @@ func (w *FileSet) RecursiveChildren(dir string) (found []File, err error) { continue } if !current.IsDir() { - found = append(found, current) + fileList = append(fileList, current) continue } children, err := readDir(current.Absolute, w.root) @@ -84,7 +90,7 @@ func (w *FileSet) RecursiveChildren(dir string) (found []File, err error) { } queue = append(queue, children...) 
} - return found, nil + return fileList, nil } func readDir(dir, root string) (queue []File, err error) { diff --git a/git/fileset_test.go b/git/fileset_test.go new file mode 100644 index 00000000..94b114a3 --- /dev/null +++ b/git/fileset_test.go @@ -0,0 +1,57 @@ +package git + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRecusiveListFile(t *testing.T) { + // Create .gitignore and ignore .gitignore and any files in + // ignored_dir + projectDir := t.TempDir() + f3, err := os.Create(filepath.Join(projectDir, ".gitignore")) + assert.NoError(t, err) + defer f3.Close() + f3.WriteString(".gitignore\nignored_dir") + + // create config file + f4, err := os.Create(filepath.Join(projectDir, "databricks.yml")) + assert.NoError(t, err) + defer f4.Close() + + // config file is returned + // .gitignore is not because we explicitly ignore it in .gitignore + fileSet := NewFileSet(projectDir) + files, err := fileSet.RecursiveListFiles(projectDir) + assert.NoError(t, err) + assert.Len(t, files, 1) + assert.Equal(t, "databricks.yml", files[0].Relative) + + // Check that newly added files not in .gitignore + // are being tracked + dir1 := filepath.Join(projectDir, "a", "b", "c") + dir2 := filepath.Join(projectDir, "ignored_dir", "e") + err = os.MkdirAll(dir2, 0o755) + assert.NoError(t, err) + err = os.MkdirAll(dir1, 0o755) + assert.NoError(t, err) + f1, err := os.Create(filepath.Join(projectDir, "a/b/c/hello.txt")) + assert.NoError(t, err) + defer f1.Close() + f2, err := os.Create(filepath.Join(projectDir, "ignored_dir/e/world.txt")) + assert.NoError(t, err) + defer f2.Close() + + files, err = fileSet.RecursiveListFiles(projectDir) + assert.NoError(t, err) + assert.Len(t, files, 2) + var fileNames []string + for _, v := range files { + fileNames = append(fileNames, v.Relative) + } + assert.Contains(t, fileNames, "databricks.yml") + assert.Contains(t, fileNames, "a/b/c/hello.txt") +} diff --git a/git/git.go b/git/git.go index 
75d5d2c7..9d800b91 100644 --- a/git/git.go +++ b/git/git.go @@ -1,6 +1,7 @@ package git import ( + "context" "fmt" "net/url" "os" @@ -8,6 +9,8 @@ import ( "strings" "github.com/databricks/bricks/folders" + "github.com/databricks/bricks/utilities" + "github.com/databricks/databricks-sdk-go/workspaces" giturls "github.com/whilp/git-urls" "gopkg.in/ini.v1" ) @@ -78,3 +81,18 @@ func RepositoryName() (string, error) { base := path.Base(origin.Path) return strings.TrimSuffix(base, ".git"), nil } + +func RepoExists(remotePath string, ctx context.Context, wsc *workspaces.WorkspacesClient) (bool, error) { + repos, err := utilities.GetAllRepos(ctx, wsc, remotePath) + if err != nil { + return false, fmt.Errorf("could not get repos: %s", err) + } + foundRepo := false + for _, repo := range repos { + if repo.Path == remotePath { + foundRepo = true + break + } + } + return foundRepo, nil +} diff --git a/internal/helpers.go b/internal/helpers.go new file mode 100644 index 00000000..7ca9f3a6 --- /dev/null +++ b/internal/helpers.go @@ -0,0 +1,35 @@ +package internal + +import ( + "fmt" + "math/rand" + "os" + "strings" + "testing" + "time" +) + +const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + +// GetEnvOrSkipTest proceeds with test only with that env variable +func GetEnvOrSkipTest(t *testing.T, name string) string { + value := os.Getenv(name) + if value == "" { + t.Skipf("Environment variable %s is missing", name) + } + return value +} + +// RandomName gives random name with optional prefix. e.g. 
qa.RandomName("tf-") +func RandomName(prefix ...string) string { + rand.Seed(time.Now().UnixNano()) + randLen := 12 + b := make([]byte, randLen) + for i := range b { + b[i] = charset[rand.Intn(randLen)] + } + if len(prefix) > 0 { + return fmt.Sprintf("%s%s", strings.Join(prefix, ""), b) + } + return string(b) +} diff --git a/internal/sync_test.go b/internal/sync_test.go new file mode 100644 index 00000000..7eec9f74 --- /dev/null +++ b/internal/sync_test.go @@ -0,0 +1,130 @@ +package internal + +import ( + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "testing" + "time" + + "github.com/databricks/databricks-sdk-go/service/repos" + "github.com/databricks/databricks-sdk-go/service/workspace" + "github.com/databricks/databricks-sdk-go/workspaces" + "github.com/stretchr/testify/assert" +) + +func TestAccSync(t *testing.T) { + t.Log(GetEnvOrSkipTest(t, "CLOUD_ENV")) + + wsc := workspaces.New() + ctx := context.Background() + me, err := wsc.CurrentUser.Me(ctx) + assert.NoError(t, err) + repoUrl := "https://github.com/shreyas-goenka/empty-repo.git" + repoPath := fmt.Sprintf("/Repos/%s/%s", me.UserName, RandomName("empty-repo-sync-integration-")) + + repoInfo, err := wsc.Repos.Create(ctx, repos.CreateRepo{ + Path: repoPath, + Url: repoUrl, + Provider: "gitHub", + }) + assert.NoError(t, err) + + t.Cleanup(func() { + err := wsc.Repos.DeleteByRepoId(ctx, fmt.Sprint(repoInfo.Id)) + assert.NoError(t, err) + }) + + // clone public remote repo + tempDir := t.TempDir() + cmd := exec.Command("git", "clone", repoUrl) + cmd.Dir = tempDir + err = cmd.Run() + assert.NoError(t, err) + + // Initialize the databricks.yml config + projectDir := filepath.Join(tempDir, "empty-repo") + content := []byte("name: test-project\nprofile: DEFAULT") + f, err := os.Create(filepath.Join(projectDir, "databricks.yml")) + assert.NoError(t, err) + defer f.Close() + _, err = f.Write(content) + assert.NoError(t, err) + + // start bricks sync process + cmd = exec.Command("bricks", "sync", 
"--remote-path", repoPath) + cmd.Dir = projectDir + err = cmd.Start() + assert.NoError(t, err) + t.Cleanup(func() { + cmd.Process.Kill() + }) + + // First upload assertion + assert.Eventually(t, func() bool { + repoContent, err := wsc.Workspace.List(ctx, workspace.ListRequest{ + Path: repoPath, + }) + assert.NoError(t, err) + return len(repoContent.Objects) == 2 + }, 30*time.Second, time.Second) + repoContent, err := wsc.Workspace.List(ctx, workspace.ListRequest{ + Path: repoPath, + }) + assert.NoError(t, err) + var files1 []string + for _, v := range repoContent.Objects { + files1 = append(files1, filepath.Base(v.Path)) + } + assert.Len(t, files1, 2) + assert.Contains(t, files1, "databricks.yml") + assert.Contains(t, files1, ".gitkeep") + + // Create new files and assert + os.Create(filepath.Join(projectDir, "hello.txt")) + os.Create(filepath.Join(projectDir, "world.txt")) + assert.Eventually(t, func() bool { + repoContent, err := wsc.Workspace.List(ctx, workspace.ListRequest{ + Path: repoPath, + }) + assert.NoError(t, err) + return len(repoContent.Objects) == 4 + }, 30*time.Second, time.Second) + repoContent, err = wsc.Workspace.List(ctx, workspace.ListRequest{ + Path: repoPath, + }) + assert.NoError(t, err) + var files2 []string + for _, v := range repoContent.Objects { + files2 = append(files2, filepath.Base(v.Path)) + } + assert.Len(t, files2, 4) + assert.Contains(t, files2, "databricks.yml") + assert.Contains(t, files2, ".gitkeep") + assert.Contains(t, files2, "hello.txt") + assert.Contains(t, files2, "world.txt") + + // delete a file and assert + os.Remove(filepath.Join(projectDir, "hello.txt")) + assert.Eventually(t, func() bool { + repoContent, err := wsc.Workspace.List(ctx, workspace.ListRequest{ + Path: repoPath, + }) + assert.NoError(t, err) + return len(repoContent.Objects) == 3 + }, 30*time.Second, time.Second) + repoContent, err = wsc.Workspace.List(ctx, workspace.ListRequest{ + Path: repoPath, + }) + assert.NoError(t, err) + var files3 []string + 
for _, v := range repoContent.Objects { + files3 = append(files3, filepath.Base(v.Path)) + } + assert.Len(t, files3, 3) + assert.Contains(t, files3, "databricks.yml") + assert.Contains(t, files3, ".gitkeep") + assert.Contains(t, files3, "world.txt") +} diff --git a/project/project.go b/project/project.go index bce9317a..f1e94f56 100644 --- a/project/project.go +++ b/project/project.go @@ -51,19 +51,18 @@ func (i *inner) WorkspacesClient() *workspaces.WorkspacesClient { return i.wsc } -// We can replace this with go sdk once https://github.com/databricks/databricks-sdk-go/issues/56 is fixed -func (i *inner) Me() *scim.User { +func (i *inner) Me() (*scim.User, error) { i.mu.Lock() defer i.mu.Unlock() if i.me != nil { - return i.me + return i.me, nil } me, err := i.wsc.CurrentUser.Me(context.Background()) if err != nil { - panic(err) + return nil, err } i.me = me - return me + return me, nil } func (i *inner) DeploymentIsolationPrefix() string { @@ -71,7 +70,10 @@ func (i *inner) DeploymentIsolationPrefix() string { return i.project.Name } if i.project.Isolation == Soft { - me := i.Me() + me, err := i.Me() + if err != nil { + panic(err) + } return fmt.Sprintf("%s/%s", i.project.Name, me.UserName) } panic(fmt.Errorf("unknow project isolation: %s", i.project.Isolation))