diff --git a/cmd/sync/sync.go b/cmd/sync/sync.go index 798a5681..62d834e3 100644 --- a/cmd/sync/sync.go +++ b/cmd/sync/sync.go @@ -120,14 +120,26 @@ var syncCmd = &cobra.Command{ return err } - s := sync.Sync{ - LocalPath: prj.Root(), - RemotePath: *remotePath, - PersistSnapshot: *persistSnapshot, - PollInterval: *interval, + cacheDir, err := prj.CacheDir() + if err != nil { + return err } - return s.RunWatchdog(ctx, wsc) + opts := sync.SyncOptions{ + LocalPath: prj.Root(), + RemotePath: *remotePath, + PersistSnapshot: *persistSnapshot, + SnapshotBasePath: cacheDir, + PollInterval: *interval, + WorkspaceClient: wsc, + } + + s, err := sync.New(opts) + if err != nil { + return err + } + + return s.RunWatchdog(ctx) }, } diff --git a/libs/sync/snapshot.go b/libs/sync/snapshot.go index 54a7e29a..27f12b9b 100644 --- a/libs/sync/snapshot.go +++ b/libs/sync/snapshot.go @@ -16,7 +16,6 @@ import ( "encoding/hex" "github.com/databricks/bricks/git" - "github.com/databricks/bricks/project" ) // Bump it up every time a potentially breaking change is made to the snapshot schema @@ -35,6 +34,13 @@ const LatestSnapshotVersion = "v1" // local files are being synced to will make bricks cli switch to a different // snapshot for persisting/loading sync state type Snapshot struct { + // Path where this snapshot was loaded from and will be saved to. + // Intentionally not part of the snapshot state because it may be moved by the user. + SnapshotPath string `json:"-"` + + // New indicates if this is a fresh snapshot or if it was loaded from disk. + New bool `json:"-"` + // version for snapshot schema. 
Only snapshots matching the latest snapshot // schema version are used and older ones are invalidated (by deleting them) Version string `json:"version"` @@ -76,52 +82,39 @@ func GetFileName(host, remotePath string) string { // Compute path of the snapshot file on the local machine // The file name for unique for a tuple of (host, remotePath) // precisely it's the first 16 characters of md5(concat(host, remotePath)) -func (s *Snapshot) getPath(ctx context.Context) (string, error) { - prj := project.Get(ctx) - cacheDir, err := prj.CacheDir() - if err != nil { - return "", err - } - snapshotDir := filepath.Join(cacheDir, syncSnapshotDirName) +func SnapshotPath(opts *SyncOptions) (string, error) { + snapshotDir := filepath.Join(opts.SnapshotBasePath, syncSnapshotDirName) if _, err := os.Stat(snapshotDir); os.IsNotExist(err) { err = os.Mkdir(snapshotDir, os.ModeDir|os.ModePerm) if err != nil { return "", fmt.Errorf("failed to create config directory: %s", err) } } - fileName := GetFileName(s.Host, s.RemotePath) + fileName := GetFileName(opts.Host, opts.RemotePath) return filepath.Join(snapshotDir, fileName), nil } -func newSnapshot(ctx context.Context, remotePath string) (*Snapshot, error) { - prj := project.Get(ctx) - - // Get host this snapshot is for - wsc := prj.WorkspacesClient() - - // TODO: The host may be late-initialized in certain Azure setups where we - // specify the workspace by its resource ID. 
tracked in: https://databricks.atlassian.net/browse/DECO-194 - host := wsc.Config.Host - if host == "" { - return nil, fmt.Errorf("failed to resolve host for snapshot") +func newSnapshot(opts *SyncOptions) (*Snapshot, error) { + path, err := SnapshotPath(opts) + if err != nil { + return nil, err } return &Snapshot{ + SnapshotPath: path, + New: true, + Version: LatestSnapshotVersion, - Host: host, - RemotePath: remotePath, + Host: opts.Host, + RemotePath: opts.RemotePath, LastUpdatedTimes: make(map[string]time.Time), LocalToRemoteNames: make(map[string]string), RemoteToLocalNames: make(map[string]string), }, nil } -func (s *Snapshot) storeSnapshot(ctx context.Context) error { - snapshotPath, err := s.getPath(ctx) - if err != nil { - return err - } - f, err := os.OpenFile(snapshotPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) +func (s *Snapshot) Save(ctx context.Context) error { + f, err := os.OpenFile(s.SnapshotPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) if err != nil { return fmt.Errorf("failed to create/open persisted sync snapshot file: %s", err) } @@ -139,34 +132,42 @@ func (s *Snapshot) storeSnapshot(ctx context.Context) error { return nil } -func (s *Snapshot) loadSnapshot(ctx context.Context) error { - snapshotPath, err := s.getPath(ctx) +func loadOrNewSnapshot(opts *SyncOptions) (*Snapshot, error) { + snapshot, err := newSnapshot(opts) if err != nil { - return err - } - // Snapshot file not found. We do not load anything - if _, err := os.Stat(snapshotPath); os.IsNotExist(err) { - return nil + return nil, err } - snapshotCopy := Snapshot{} + // Snapshot file not found. We return the new copy. 
+ if _, err := os.Stat(snapshot.SnapshotPath); os.IsNotExist(err) { + return snapshot, nil + } - bytes, err := os.ReadFile(snapshotPath) + bytes, err := os.ReadFile(snapshot.SnapshotPath) if err != nil { - return fmt.Errorf("failed to read sync snapshot from disk: %s", err) + return nil, fmt.Errorf("failed to read sync snapshot from disk: %w", err) } - err = json.Unmarshal(bytes, &snapshotCopy) + + var fromDisk Snapshot + err = json.Unmarshal(bytes, &fromDisk) if err != nil { - return fmt.Errorf("failed to json unmarshal persisted snapshot: %s", err) + return nil, fmt.Errorf("failed to json unmarshal persisted snapshot: %w", err) } + // invalidate old snapshot with schema versions - if snapshotCopy.Version != LatestSnapshotVersion { - - log.Printf("Did not load existing snapshot because its version is %s while the latest version is %s", s.Version, LatestSnapshotVersion) - return nil + if fromDisk.Version != LatestSnapshotVersion { + log.Printf("Did not load existing snapshot because its version is %s while the latest version is %s", fromDisk.Version, LatestSnapshotVersion) + return newSnapshot(opts) } - *s = snapshotCopy - return nil + + // unmarshal again over the existing snapshot instance + err = json.Unmarshal(bytes, snapshot) + if err != nil { + return nil, fmt.Errorf("failed to json unmarshal persisted snapshot: %w", err) + } + + snapshot.New = false + return snapshot, nil } func (d diff) IsEmpty() bool { diff --git a/libs/sync/snapshot_test.go b/libs/sync/snapshot_test.go index 933a1b2b..f02fde04 100644 --- a/libs/sync/snapshot_test.go +++ b/libs/sync/snapshot_test.go @@ -1,7 +1,6 @@ package sync import ( - "context" "fmt" "os" "path/filepath" @@ -10,8 +9,8 @@ import ( "github.com/databricks/bricks/git" "github.com/databricks/bricks/libs/testfile" - "github.com/databricks/bricks/project" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func assertKeysOfMap(t *testing.T, m map[string]time.Time, expectedKeys []string) { @@ 
-219,35 +218,25 @@ func TestErrorWhenIdenticalRemoteName(t *testing.T) { assert.ErrorContains(t, err, "both foo and foo.py point to the same remote file location foo. Please remove one of them from your local project") } -func TestNewSnapshotDefaults(t *testing.T) { - ctx := setupProject(t) - snapshot, err := newSnapshot(ctx, "/Repos/foo/bar") - prj := project.Get(ctx) - assert.NoError(t, err) - - assert.Equal(t, LatestSnapshotVersion, snapshot.Version) - assert.Equal(t, "/Repos/foo/bar", snapshot.RemotePath) - assert.Equal(t, prj.WorkspacesClient().Config.Host, snapshot.Host) - assert.Empty(t, snapshot.LastUpdatedTimes) - assert.Empty(t, snapshot.RemoteToLocalNames) - assert.Empty(t, snapshot.LocalToRemoteNames) -} - -func getEmptySnapshot() Snapshot { - return Snapshot{ - LastUpdatedTimes: make(map[string]time.Time), - LocalToRemoteNames: make(map[string]string), - RemoteToLocalNames: make(map[string]string), +func defaultOptions(t *testing.T) *SyncOptions { + return &SyncOptions{ + Host: "www.foobar.com", + RemotePath: "/Repos/foo/bar", + SnapshotBasePath: t.TempDir(), } } -func setupProject(t *testing.T) context.Context { - projectDir := t.TempDir() - ctx := context.TODO() - t.Setenv("DATABRICKS_HOST", "www.foobar.com") - ctx, err := project.Initialize(ctx, projectDir, "development") - assert.NoError(t, err) - return ctx +func TestNewSnapshotDefaults(t *testing.T) { + opts := defaultOptions(t) + snapshot, err := newSnapshot(opts) + require.NoError(t, err) + + assert.Equal(t, LatestSnapshotVersion, snapshot.Version) + assert.Equal(t, opts.RemotePath, snapshot.RemotePath) + assert.Equal(t, opts.Host, snapshot.Host) + assert.Empty(t, snapshot.LastUpdatedTimes) + assert.Empty(t, snapshot.RemoteToLocalNames) + assert.Empty(t, snapshot.LocalToRemoteNames) } func TestOldSnapshotInvalidation(t *testing.T) { @@ -259,21 +248,18 @@ func TestOldSnapshotInvalidation(t *testing.T) { "local_to_remote_names": {}, "remote_to_local_names": {} }` - ctx := setupProject(t) - 
emptySnapshot := getEmptySnapshot() - snapshotPath, err := emptySnapshot.getPath(ctx) - assert.NoError(t, err) + opts := defaultOptions(t) + snapshotPath, err := SnapshotPath(opts) + require.NoError(t, err) snapshotFile := testfile.CreateFile(t, snapshotPath) snapshotFile.Overwrite(t, oldVersionSnapshot) snapshotFile.Close(t) - assert.FileExists(t, snapshotPath) - snapshot := emptySnapshot - err = snapshot.loadSnapshot(ctx) - assert.NoError(t, err) // assert snapshot did not get loaded - assert.Equal(t, emptySnapshot, snapshot) + snapshot, err := loadOrNewSnapshot(opts) + require.NoError(t, err) + assert.True(t, snapshot.New) } func TestNoVersionSnapshotInvalidation(t *testing.T) { @@ -284,21 +270,18 @@ func TestNoVersionSnapshotInvalidation(t *testing.T) { "local_to_remote_names": {}, "remote_to_local_names": {} }` - ctx := setupProject(t) - emptySnapshot := getEmptySnapshot() - snapshotPath, err := emptySnapshot.getPath(ctx) - assert.NoError(t, err) + opts := defaultOptions(t) + snapshotPath, err := SnapshotPath(opts) + require.NoError(t, err) snapshotFile := testfile.CreateFile(t, snapshotPath) snapshotFile.Overwrite(t, noVersionSnapshot) snapshotFile.Close(t) - assert.FileExists(t, snapshotPath) - snapshot := emptySnapshot - err = snapshot.loadSnapshot(ctx) - assert.NoError(t, err) // assert snapshot did not get loaded - assert.Equal(t, emptySnapshot, snapshot) + snapshot, err := loadOrNewSnapshot(opts) + require.NoError(t, err) + assert.True(t, snapshot.New) } func TestLatestVersionSnapshotGetsLoaded(t *testing.T) { @@ -311,22 +294,17 @@ func TestLatestVersionSnapshotGetsLoaded(t *testing.T) { "remote_to_local_names": {} }`, LatestSnapshotVersion) - ctx := setupProject(t) - emptySnapshot := getEmptySnapshot() - snapshotPath, err := emptySnapshot.getPath(ctx) - assert.NoError(t, err) - + opts := defaultOptions(t) + snapshotPath, err := SnapshotPath(opts) + require.NoError(t, err) snapshotFile := testfile.CreateFile(t, snapshotPath) snapshotFile.Overwrite(t, 
latestVersionSnapshot) snapshotFile.Close(t) - assert.FileExists(t, snapshotPath) - snapshot := emptySnapshot - err = snapshot.loadSnapshot(ctx) - assert.NoError(t, err) - // assert snapshot gets loaded - assert.NotEqual(t, emptySnapshot, snapshot) + snapshot, err := loadOrNewSnapshot(opts) + require.NoError(t, err) + assert.False(t, snapshot.New) assert.Equal(t, LatestSnapshotVersion, snapshot.Version) assert.Equal(t, "www.foobar.com", snapshot.Host) assert.Equal(t, "/Repos/foo/bar", snapshot.RemotePath) diff --git a/libs/sync/sync.go b/libs/sync/sync.go index 9a247254..55055f8a 100644 --- a/libs/sync/sync.go +++ b/libs/sync/sync.go @@ -2,25 +2,60 @@ package sync import ( "context" + "fmt" "time" + "github.com/databricks/bricks/git" "github.com/databricks/bricks/libs/sync/repofiles" "github.com/databricks/databricks-sdk-go" ) -type Sync struct { +type SyncOptions struct { LocalPath string RemotePath string PersistSnapshot bool + SnapshotBasePath string + PollInterval time.Duration + + WorkspaceClient *databricks.WorkspaceClient + + Host string +} + +type Sync struct { + *SyncOptions + + fileSet *git.FileSet +} + +// New initializes and returns a new [Sync] instance. +func New(opts SyncOptions) (*Sync, error) { + fileSet := git.NewFileSet(opts.LocalPath) + err := fileSet.EnsureValidGitIgnoreExists() + if err != nil { + return nil, err + } + + // TODO: The host may be late-initialized in certain Azure setups where we + // specify the workspace by its resource ID. tracked in: https://databricks.atlassian.net/browse/DECO-194 + opts.Host = opts.WorkspaceClient.Config.Host + if opts.Host == "" { + return nil, fmt.Errorf("failed to resolve host for snapshot") + } + + return &Sync{ + SyncOptions: &opts, + fileSet: fileSet, + }, nil } // RunWatchdog kicks off a polling loop to monitor local changes and synchronize // them to the remote workspace path. 
-func (s *Sync) RunWatchdog(ctx context.Context, wsc *databricks.WorkspaceClient) error { - repoFiles := repofiles.Create(s.RemotePath, s.LocalPath, wsc) +func (s *Sync) RunWatchdog(ctx context.Context) error { + repoFiles := repofiles.Create(s.RemotePath, s.LocalPath, s.WorkspaceClient) syncCallback := syncCallback(ctx, repoFiles) - return spawnWatchdog(ctx, s.PollInterval, syncCallback, s.RemotePath, s.PersistSnapshot) + return spawnWatchdog(ctx, syncCallback, s) } diff --git a/libs/sync/watchdog.go b/libs/sync/watchdog.go index 8e45b0a6..bf8b7717 100644 --- a/libs/sync/watchdog.go +++ b/libs/sync/watchdog.go @@ -7,7 +7,6 @@ import ( "time" "github.com/databricks/bricks/libs/sync/repofiles" - "github.com/databricks/bricks/project" "golang.org/x/sync/errgroup" ) @@ -17,7 +16,7 @@ type watchdog struct { wg sync.WaitGroup failure error // data race? make channel? - persistSnapshot bool + sync *Sync } // See https://docs.databricks.com/resources/limits.html#limits-api-rate-limits for per api @@ -70,16 +69,14 @@ func syncCallback(ctx context.Context, repoFiles *repofiles.RepoFiles) func(loca } func spawnWatchdog(ctx context.Context, - interval time.Duration, applyDiff func(diff) error, - remotePath string, - persistSnapshot bool) error { + sync *Sync) error { w := &watchdog{ - ticker: time.NewTicker(interval), - persistSnapshot: persistSnapshot, + ticker: time.NewTicker(sync.PollInterval), + sync: sync, } w.wg.Add(1) - go w.main(ctx, applyDiff, remotePath) + go w.main(ctx, applyDiff, sync.RemotePath) w.wg.Wait() return w.failure } @@ -88,28 +85,27 @@ func spawnWatchdog(ctx context.Context, // https://github.com/gorakhargosh/watchdog/blob/master/src/watchdog/observers/kqueue.py#L394-L418 func (w *watchdog) main(ctx context.Context, applyDiff func(diff) error, remotePath string) { defer w.wg.Done() - snapshot, err := newSnapshot(ctx, remotePath) + snapshot, err := newSnapshot(w.sync.SyncOptions) if err != nil { log.Printf("[ERROR] cannot create snapshot: %s", err) 
w.failure = err return } - if w.persistSnapshot { - err := snapshot.loadSnapshot(ctx) + if w.sync.PersistSnapshot { + snapshot, err = loadOrNewSnapshot(w.sync.SyncOptions) if err != nil { log.Printf("[ERROR] cannot load snapshot: %s", err) w.failure = err return } } - prj := project.Get(ctx) var onlyOnceInitLog sync.Once for { select { case <-ctx.Done(): return case <-w.ticker.C: - all, err := prj.GetFileSet().All() + all, err := w.sync.fileSet.All() if err != nil { log.Printf("[ERROR] cannot list files: %s", err) w.failure = err @@ -132,8 +128,8 @@ func (w *watchdog) main(ctx context.Context, applyDiff func(diff) error, remoteP w.failure = err return } - if w.persistSnapshot { - err = snapshot.storeSnapshot(ctx) + if w.sync.PersistSnapshot { + err = snapshot.Save(ctx) if err != nil { log.Printf("[ERROR] cannot store snapshot: %s", err) w.failure = err