diff --git a/bundle/config/mutator/process_root_includes_test.go b/bundle/config/mutator/process_root_includes_test.go index 88a6c743..645eb89a 100644 --- a/bundle/config/mutator/process_root_includes_test.go +++ b/bundle/config/mutator/process_root_includes_test.go @@ -4,7 +4,6 @@ import ( "context" "os" "path" - "path/filepath" "runtime" "strings" "testing" @@ -13,16 +12,11 @@ import ( "github.com/databricks/cli/bundle/config" "github.com/databricks/cli/bundle/config/mutator" "github.com/databricks/cli/bundle/env" + "github.com/databricks/cli/internal/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func touch(t *testing.T, path, file string) { - f, err := os.Create(filepath.Join(path, file)) - require.NoError(t, err) - f.Close() -} - func TestProcessRootIncludesEmpty(t *testing.T) { b := &bundle.Bundle{ Config: config.Root{ @@ -64,9 +58,9 @@ func TestProcessRootIncludesSingleGlob(t *testing.T) { }, } - touch(t, b.Config.Path, "databricks.yml") - touch(t, b.Config.Path, "a.yml") - touch(t, b.Config.Path, "b.yml") + testutil.Touch(t, b.Config.Path, "databricks.yml") + testutil.Touch(t, b.Config.Path, "a.yml") + testutil.Touch(t, b.Config.Path, "b.yml") err := bundle.Apply(context.Background(), b, mutator.ProcessRootIncludes()) require.NoError(t, err) @@ -85,8 +79,8 @@ func TestProcessRootIncludesMultiGlob(t *testing.T) { }, } - touch(t, b.Config.Path, "a1.yml") - touch(t, b.Config.Path, "b1.yml") + testutil.Touch(t, b.Config.Path, "a1.yml") + testutil.Touch(t, b.Config.Path, "b1.yml") err := bundle.Apply(context.Background(), b, mutator.ProcessRootIncludes()) require.NoError(t, err) @@ -105,7 +99,7 @@ func TestProcessRootIncludesRemoveDups(t *testing.T) { }, } - touch(t, b.Config.Path, "a.yml") + testutil.Touch(t, b.Config.Path, "a.yml") err := bundle.Apply(context.Background(), b, mutator.ProcessRootIncludes()) require.NoError(t, err) @@ -129,7 +123,7 @@ func TestProcessRootIncludesNotExists(t *testing.T) { func TestProcessRootIncludesExtrasFromEnvVar(t *testing.T) { rootPath := t.TempDir() testYamlName := "extra_include_path.yml" - touch(t, rootPath, testYamlName) + testutil.Touch(t, rootPath, testYamlName) t.Setenv(env.IncludesVariable, path.Join(rootPath, testYamlName)) b := &bundle.Bundle{ @@ -146,7 +140,7 @@ func TestProcessRootIncludesExtrasFromEnvVar(t *testing.T) { func TestProcessRootIncludesDedupExtrasFromEnvVar(t *testing.T) { rootPath := t.TempDir() testYamlName := "extra_include_path.yml" - touch(t, rootPath, testYamlName) + testutil.Touch(t, rootPath, testYamlName) t.Setenv(env.IncludesVariable, strings.Join( []string{ path.Join(rootPath, testYamlName), diff --git a/bundle/deploy/filer.go b/bundle/deploy/filer.go new file mode 100644 index 00000000..c0fd839e --- /dev/null +++ b/bundle/deploy/filer.go @@ -0,0 +1,14 @@ +package deploy + +import ( + "github.com/databricks/cli/bundle" + "github.com/databricks/cli/libs/filer" +) + +// FilerFactory is a function that returns a filer.Filer. +type FilerFactory func(b *bundle.Bundle) (filer.Filer, error) + +// StateFiler returns a filer.Filer that can be used to read/write state files. +func StateFiler(b *bundle.Bundle) (filer.Filer, error) { + return filer.NewWorkspaceFilesClient(b.WorkspaceClient(), b.Config.Workspace.StatePath) +} diff --git a/bundle/deploy/files/delete.go b/bundle/deploy/files/delete.go index 9f7ad4d4..8585ec3c 100644 --- a/bundle/deploy/files/delete.go +++ b/bundle/deploy/files/delete.go @@ -45,7 +45,7 @@ func (m *delete) Apply(ctx context.Context, b *bundle.Bundle) error { } // Clean up sync snapshot file - sync, err := getSync(ctx, b) + sync, err := GetSync(ctx, b) if err != nil { return err } diff --git a/bundle/deploy/files/sync.go b/bundle/deploy/files/sync.go index 148a63ff..8de80c22 100644 --- a/bundle/deploy/files/sync.go +++ b/bundle/deploy/files/sync.go @@ -8,7 +8,15 @@ import ( "github.com/databricks/cli/libs/sync" ) -func getSync(ctx context.Context, b *bundle.Bundle) (*sync.Sync, error) { +func GetSync(ctx context.Context, b *bundle.Bundle) (*sync.Sync, error) { + opts, err := GetSyncOptions(ctx, b) + if err != nil { + return nil, fmt.Errorf("cannot get sync options: %w", err) + } + return sync.New(ctx, *opts) +} + +func GetSyncOptions(ctx context.Context, b *bundle.Bundle) (*sync.SyncOptions, error) { cacheDir, err := b.CacheDir(ctx) if err != nil { return nil, fmt.Errorf("cannot get bundle cache directory: %w", err) @@ -19,17 +27,22 @@ func getSync(ctx context.Context, b *bundle.Bundle) (*sync.Sync, error) { return nil, fmt.Errorf("cannot get list of sync includes: %w", err) } - opts := sync.SyncOptions{ + opts := &sync.SyncOptions{ LocalPath: b.Config.Path, RemotePath: b.Config.Workspace.FilePath, Include: includes, Exclude: b.Config.Sync.Exclude, + Host: b.WorkspaceClient().Config.Host, - Full: false, - CurrentUser: b.Config.Workspace.CurrentUser.User, + Full: false, SnapshotBasePath: cacheDir, WorkspaceClient: b.WorkspaceClient(), } - return sync.New(ctx, opts) + + if b.Config.Workspace.CurrentUser != nil { + opts.CurrentUser = b.Config.Workspace.CurrentUser.User + } + + return opts, nil } diff --git a/bundle/deploy/files/upload.go b/bundle/deploy/files/upload.go index 26d1ef4b..4da41e20 100644 --- a/bundle/deploy/files/upload.go +++ b/bundle/deploy/files/upload.go @@ -17,7 +17,7 @@ func (m *upload) Name() string { func (m *upload) Apply(ctx context.Context, b *bundle.Bundle) error { cmdio.LogString(ctx, fmt.Sprintf("Uploading bundle files to %s...", b.Config.Workspace.FilePath)) - sync, err := getSync(ctx, b) + sync, err := GetSync(ctx, b) if err != nil { return err } diff --git a/bundle/deploy/state.go b/bundle/deploy/state.go new file mode 100644 index 00000000..ffcadc9d --- /dev/null +++ b/bundle/deploy/state.go @@ -0,0 +1,174 @@ +package deploy + +import ( + "context" + "encoding/json" + "fmt" + "io" + "io/fs" + "os" + "path/filepath" + "time" + + "github.com/databricks/cli/bundle" + "github.com/databricks/cli/libs/fileset" +) + +const DeploymentStateFileName = "deployment.json" +const DeploymentStateVersion = 1 + +type File struct { + LocalPath string `json:"local_path"` + + // If true, this file is a notebook. + // This property must be persisted because notebooks are stripped of their extension. + // If the local file is no longer present, we need to know what to remove on the workspace side. + IsNotebook bool `json:"is_notebook"` +} + +type Filelist []File + +type DeploymentState struct { + // Version is the version of the deployment state. + // To be incremented when the schema changes. + Version int64 `json:"version"` + + // Seq is the sequence number of the deployment state. + // This number is incremented on every deployment. + // It is used to detect if the deployment state is stale. + Seq int64 `json:"seq"` + + // CliVersion is the version of the CLI which created the deployment state. + CliVersion string `json:"cli_version"` + + // Timestamp is the time when the deployment state was created. + Timestamp time.Time `json:"timestamp"` + + // Files is a list of files which has been deployed as part of this deployment. + Files Filelist `json:"files"` +} + +// We use this entry type as a proxy to fs.DirEntry. +// When we construct sync snapshot from deployment state, +// we use a fileset.File which embeds fs.DirEntry as the DirEntry field. +// Because we can't marshal/unmarshal fs.DirEntry directly, instead when we unmarshal +// the deployment state, we use this entry type to represent the fs.DirEntry in fileset.File instance. +type entry struct { + path string + info fs.FileInfo +} + +func newEntry(path string) *entry { + info, err := os.Stat(path) + if err != nil { + return &entry{path, nil} + } + + return &entry{path, info} +} + +func (e *entry) Name() string { + return filepath.Base(e.path) +} + +func (e *entry) IsDir() bool { + // If the entry is nil, it is a non-existent file so return false. + if e.info == nil { + return false + } + return e.info.IsDir() +} + +func (e *entry) Type() fs.FileMode { + // If the entry is nil, it is a non-existent file so return 0. + if e.info == nil { + return 0 + } + return e.info.Mode() +} + +func (e *entry) Info() (fs.FileInfo, error) { + if e.info == nil { + return nil, fmt.Errorf("no info available") + } + return e.info, nil +} + +func FromSlice(files []fileset.File) (Filelist, error) { + var f Filelist + for k := range files { + file := &files[k] + isNotebook, err := file.IsNotebook() + if err != nil { + return nil, err + } + f = append(f, File{ + LocalPath: file.Relative, + IsNotebook: isNotebook, + }) + } + return f, nil +} + +func (f Filelist) ToSlice(basePath string) []fileset.File { + var files []fileset.File + for _, file := range f { + absPath := filepath.Join(basePath, file.LocalPath) + if file.IsNotebook { + files = append(files, fileset.NewNotebookFile(newEntry(absPath), absPath, file.LocalPath)) + } else { + files = append(files, fileset.NewSourceFile(newEntry(absPath), absPath, file.LocalPath)) + } + } + return files +} + +func isLocalStateStale(local io.Reader, remote io.Reader) bool { + localState, err := loadState(local) + if err != nil { + return true + } + + remoteState, err := loadState(remote) + if err != nil { + return false + } + + return localState.Seq < remoteState.Seq +} + +func validateRemoteStateCompatibility(remote io.Reader) error { + state, err := loadState(remote) + if err != nil { + return err + } + + // If the remote state version is greater than the CLI version, we can't proceed. + if state.Version > DeploymentStateVersion { + return fmt.Errorf("remote deployment state is incompatible with the current version of the CLI, please upgrade to at least %s", state.CliVersion) + } + + return nil +} + +func loadState(r io.Reader) (*DeploymentState, error) { + content, err := io.ReadAll(r) + if err != nil { + return nil, err + } + var s DeploymentState + err = json.Unmarshal(content, &s) + if err != nil { + return nil, err + } + + return &s, nil +} + +func getPathToStateFile(ctx context.Context, b *bundle.Bundle) (string, error) { + cacheDir, err := b.CacheDir(ctx) + if err != nil { + return "", fmt.Errorf("cannot get bundle cache directory: %w", err) + } + return filepath.Join(cacheDir, DeploymentStateFileName), nil +} diff --git a/bundle/deploy/state_pull.go b/bundle/deploy/state_pull.go new file mode 100644 index 00000000..089a870c --- /dev/null +++ b/bundle/deploy/state_pull.go @@ -0,0 +1,126 @@ +package deploy + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "io/fs" + "os" + + "github.com/databricks/cli/bundle" + "github.com/databricks/cli/bundle/deploy/files" + "github.com/databricks/cli/libs/filer" + "github.com/databricks/cli/libs/log" + "github.com/databricks/cli/libs/sync" +) + +type statePull struct { + filerFactory FilerFactory +} + +func (s *statePull) Apply(ctx context.Context, b *bundle.Bundle) error { + f, err := s.filerFactory(b) + if err != nil { + return err + } + + // Download deployment state file from filer to local cache directory. + log.Infof(ctx, "Opening remote deployment state file") + remote, err := s.remoteState(ctx, f) + if err != nil { + log.Infof(ctx, "Unable to open remote deployment state file: %s", err) + return err + } + if remote == nil { + log.Infof(ctx, "Remote deployment state file does not exist") + return nil + } + + statePath, err := getPathToStateFile(ctx, b) + if err != nil { + return err + } + + local, err := os.OpenFile(statePath, os.O_CREATE|os.O_RDWR, 0600) + if err != nil { + return err + } + defer local.Close() + + data := remote.Bytes() + err = validateRemoteStateCompatibility(bytes.NewReader(data)) + if err != nil { + return err + } + + if !isLocalStateStale(local, bytes.NewReader(data)) { + log.Infof(ctx, "Local deployment state is the same or newer, ignoring remote state") + return nil + } + + // Truncating the file before writing + local.Truncate(0) + local.Seek(0, 0) + + // Write file to disk. + log.Infof(ctx, "Writing remote deployment state file to local cache directory") + _, err = io.Copy(local, bytes.NewReader(data)) + if err != nil { + return err + } + + var state DeploymentState + err = json.Unmarshal(data, &state) + if err != nil { + return err + } + + // Create a new snapshot based on the deployment state file. + opts, err := files.GetSyncOptions(ctx, b) + if err != nil { + return err + } + + log.Infof(ctx, "Creating new snapshot") + snapshot, err := sync.NewSnapshot(state.Files.ToSlice(b.Config.Path), opts) + if err != nil { + return err + } + + // Persist the snapshot to disk. + log.Infof(ctx, "Persisting snapshot to disk") + return snapshot.Save(ctx) +} + +func (s *statePull) remoteState(ctx context.Context, f filer.Filer) (*bytes.Buffer, error) { + // Download deployment state file from filer to local cache directory. + remote, err := f.Read(ctx, DeploymentStateFileName) + if err != nil { + // On first deploy this file doesn't yet exist. + if errors.Is(err, fs.ErrNotExist) { + return nil, nil + } + return nil, err + } + + defer remote.Close() + + var buf bytes.Buffer + _, err = io.Copy(&buf, remote) + if err != nil { + return nil, err + } + + return &buf, nil +} + +func (s *statePull) Name() string { + return "deploy:state-pull" +} + +// StatePull returns a mutator that pulls the deployment state from the Databricks workspace +func StatePull() bundle.Mutator { + return &statePull{StateFiler} +} diff --git a/bundle/deploy/state_pull_test.go b/bundle/deploy/state_pull_test.go new file mode 100644 index 00000000..50eb9091 --- /dev/null +++ b/bundle/deploy/state_pull_test.go @@ -0,0 +1,457 @@ +package deploy + +import ( + "bytes" + "context" + "encoding/json" + "io" + "os" + "path/filepath" + "testing" + + "github.com/databricks/cli/bundle" + "github.com/databricks/cli/bundle/config" + "github.com/databricks/cli/bundle/deploy/files" + mockfiler "github.com/databricks/cli/internal/mocks/libs/filer" + "github.com/databricks/cli/internal/testutil" + "github.com/databricks/cli/libs/filer" + "github.com/databricks/cli/libs/sync" + "github.com/databricks/databricks-sdk-go/service/iam" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +type snapshortStateExpectations struct { + localToRemoteNames map[string]string + remoteToLocalNames map[string]string +} + +type statePullExpectations struct { + seq int + filesInDevelopmentState []File + snapshotState *snapshortStateExpectations +} + +type statePullOpts struct { + files []File + seq int + localFiles []string + localNotebooks []string + expects statePullExpectations + withExistingSnapshot bool + localState *DeploymentState +} + +func testStatePull(t *testing.T, opts statePullOpts) { + s := &statePull{func(b *bundle.Bundle) (filer.Filer, error) { + f := mockfiler.NewMockFiler(t) + + deploymentStateData, err := json.Marshal(DeploymentState{ + Version: DeploymentStateVersion, + Seq: int64(opts.seq), + Files: opts.files, + }) + require.NoError(t, err) + + f.EXPECT().Read(mock.Anything, DeploymentStateFileName).Return(io.NopCloser(bytes.NewReader(deploymentStateData)), nil) + + return f, nil + }} + + b := &bundle.Bundle{ + Config: config.Root{ + Path: t.TempDir(), + Bundle: config.Bundle{ + Target: "default", + }, + Workspace: config.Workspace{ + StatePath: "/state", + CurrentUser: &config.User{ + User: &iam.User{ + UserName: "test-user", + }, + }, + }, + }, + } + ctx := context.Background() + + for _, file := range opts.localFiles { + testutil.Touch(t, filepath.Join(b.Config.Path, "bar"), file) + } + + for _, file := range opts.localNotebooks { + testutil.TouchNotebook(t, filepath.Join(b.Config.Path, "bar"), file) + } + + if opts.withExistingSnapshot { + opts, err := files.GetSyncOptions(ctx, b) + require.NoError(t, err) + + snapshotPath, err := sync.SnapshotPath(opts) + require.NoError(t, err) + + err = os.WriteFile(snapshotPath, []byte("snapshot"), 0644) + require.NoError(t, err) + } + + if opts.localState != nil { + statePath, err := getPathToStateFile(ctx, b) + require.NoError(t, err) + + data, err := json.Marshal(opts.localState) + require.NoError(t, err) + + err = os.WriteFile(statePath, data, 0644) + require.NoError(t, err) + } + + err := bundle.Apply(ctx, b, s) + require.NoError(t, err) + + // Check that deployment state was written + statePath, err := getPathToStateFile(ctx, b) + require.NoError(t, err) + + data, err := os.ReadFile(statePath) + require.NoError(t, err) + + var state DeploymentState + err = json.Unmarshal(data, &state) + require.NoError(t, err) + + require.Equal(t, int64(opts.expects.seq), state.Seq) + require.Len(t, state.Files, len(opts.expects.filesInDevelopmentState)) + for i, file := range opts.expects.filesInDevelopmentState { + require.Equal(t, file.LocalPath, state.Files[i].LocalPath) + } + + if opts.expects.snapshotState != nil { + syncOpts, err := files.GetSyncOptions(ctx, b) + require.NoError(t, err) + + snapshotPath, err := sync.SnapshotPath(syncOpts) + require.NoError(t, err) + + _, err = os.Stat(snapshotPath) + require.NoError(t, err) + + data, err = os.ReadFile(snapshotPath) + require.NoError(t, err) + + var snapshot sync.Snapshot + err = json.Unmarshal(data, &snapshot) + require.NoError(t, err) + + snapshotState := snapshot.SnapshotState + require.Len(t, snapshotState.LocalToRemoteNames, len(opts.expects.snapshotState.localToRemoteNames)) + for local, remote := range opts.expects.snapshotState.localToRemoteNames { + require.Equal(t, remote, snapshotState.LocalToRemoteNames[local]) + } + + require.Len(t, snapshotState.RemoteToLocalNames, len(opts.expects.snapshotState.remoteToLocalNames)) + for remote, local := range opts.expects.snapshotState.remoteToLocalNames { + require.Equal(t, local, snapshotState.RemoteToLocalNames[remote]) + } + } +} + +var stateFiles []File = []File{ + { + LocalPath: "bar/t1.py", + IsNotebook: false, + }, + { + LocalPath: "bar/t2.py", + IsNotebook: false, + }, + { + LocalPath: "bar/notebook.py", + IsNotebook: true, + }, +} + +func TestStatePull(t *testing.T) { + testStatePull(t, statePullOpts{ + seq: 1, + files: stateFiles, + localFiles: []string{"t1.py", "t2.py"}, + localNotebooks: []string{"notebook.py"}, + expects: statePullExpectations{ + seq: 1, + filesInDevelopmentState: []File{ + { + LocalPath: "bar/t1.py", + }, + { + LocalPath: "bar/t2.py", + }, + { + LocalPath: "bar/notebook.py", + }, + }, + snapshotState: &snapshortStateExpectations{ + localToRemoteNames: map[string]string{ + "bar/t1.py": "bar/t1.py", + "bar/t2.py": "bar/t2.py", + "bar/notebook.py": "bar/notebook", + }, + remoteToLocalNames: map[string]string{ + "bar/t1.py": "bar/t1.py", + "bar/t2.py": "bar/t2.py", + "bar/notebook": "bar/notebook.py", + }, + }, + }, + }) +} + +func TestStatePullSnapshotExists(t *testing.T) { + testStatePull(t, statePullOpts{ + withExistingSnapshot: true, + seq: 1, + files: stateFiles, + localFiles: []string{"t1.py", "t2.py"}, + expects: statePullExpectations{ + seq: 1, + filesInDevelopmentState: []File{ + { + LocalPath: "bar/t1.py", + }, + { + LocalPath: "bar/t2.py", + }, + { + LocalPath: "bar/notebook.py", + }, + }, + snapshotState: &snapshortStateExpectations{ + localToRemoteNames: map[string]string{ + "bar/t1.py": "bar/t1.py", + "bar/t2.py": "bar/t2.py", + "bar/notebook.py": "bar/notebook", + }, + remoteToLocalNames: map[string]string{ + "bar/t1.py": "bar/t1.py", + "bar/t2.py": "bar/t2.py", + "bar/notebook": "bar/notebook.py", + }, + }, + }, + }) +} + +func TestStatePullNoState(t *testing.T) { + s := &statePull{func(b *bundle.Bundle) (filer.Filer, error) { + f := mockfiler.NewMockFiler(t) + + f.EXPECT().Read(mock.Anything, DeploymentStateFileName).Return(nil, os.ErrNotExist) + + return f, nil + }} + + b := &bundle.Bundle{ + Config: config.Root{ + Path: t.TempDir(), + Bundle: config.Bundle{ + Target: "default", + }, + Workspace: config.Workspace{ + StatePath: "/state", + }, + }, + } + ctx := context.Background() + + err := bundle.Apply(ctx, b, s) + require.NoError(t, err) + + // Check that deployment state was not written + statePath, err := getPathToStateFile(ctx, b) + require.NoError(t, err) + + _, err = os.Stat(statePath) + require.True(t, os.IsNotExist(err)) +} + +func TestStatePullOlderState(t *testing.T) { + testStatePull(t, statePullOpts{ + seq: 1, + files: stateFiles, + localFiles: []string{"t1.py", "t2.py"}, + localNotebooks: []string{"notebook.py"}, + localState: &DeploymentState{ + Version: DeploymentStateVersion, + Seq: 2, + Files: []File{ + { + LocalPath: "bar/t1.py", + }, + }, + }, + expects: statePullExpectations{ + seq: 2, + filesInDevelopmentState: []File{ + { + LocalPath: "bar/t1.py", + }, + }, + }, + }) +} + +func TestStatePullNewerState(t *testing.T) { + testStatePull(t, statePullOpts{ + seq: 1, + files: stateFiles, + localFiles: []string{"t1.py", "t2.py"}, + localNotebooks: []string{"notebook.py"}, + localState: &DeploymentState{ + Version: DeploymentStateVersion, + Seq: 0, + Files: []File{ + { + LocalPath: "bar/t1.py", + }, + }, + }, + expects: statePullExpectations{ + seq: 1, + filesInDevelopmentState: []File{ + { + LocalPath: "bar/t1.py", + }, + { + LocalPath: "bar/t2.py", + }, + { + LocalPath: "bar/notebook.py", + }, + }, + snapshotState: &snapshortStateExpectations{ + localToRemoteNames: map[string]string{ + "bar/t1.py": "bar/t1.py", + "bar/t2.py": "bar/t2.py", + "bar/notebook.py": "bar/notebook", + }, + remoteToLocalNames: map[string]string{ + "bar/t1.py": "bar/t1.py", + "bar/t2.py": "bar/t2.py", + "bar/notebook": "bar/notebook.py", + }, + }, + }, + }) +} + +func TestStatePullAndFileIsRemovedLocally(t *testing.T) { + testStatePull(t, statePullOpts{ + seq: 1, + files: stateFiles, + localFiles: []string{"t2.py"}, // t1.py is removed locally + localNotebooks: []string{"notebook.py"}, + expects: statePullExpectations{ + seq: 1, + filesInDevelopmentState: []File{ + { + LocalPath: "bar/t1.py", + }, + { + LocalPath: "bar/t2.py", + }, + { + LocalPath: "bar/notebook.py", + }, + }, + snapshotState: &snapshortStateExpectations{ + localToRemoteNames: map[string]string{ + "bar/t1.py": "bar/t1.py", + "bar/t2.py": "bar/t2.py", + "bar/notebook.py": "bar/notebook", + }, + remoteToLocalNames: map[string]string{ + "bar/t1.py": "bar/t1.py", + "bar/t2.py": "bar/t2.py", + "bar/notebook": "bar/notebook.py", + }, + }, + }, + }) +} + +func TestStatePullAndNotebookIsRemovedLocally(t *testing.T) { + testStatePull(t, statePullOpts{ + seq: 1, + files: stateFiles, + localFiles: []string{"t1.py", "t2.py"}, + localNotebooks: []string{}, // notebook.py is removed locally + expects: statePullExpectations{ + seq: 1, + filesInDevelopmentState: []File{ + { + LocalPath: "bar/t1.py", + }, + { + LocalPath: "bar/t2.py", + }, + { + LocalPath: "bar/notebook.py", + }, + }, + snapshotState: &snapshortStateExpectations{ + localToRemoteNames: map[string]string{ + "bar/t1.py": "bar/t1.py", + "bar/t2.py": "bar/t2.py", + "bar/notebook.py": "bar/notebook", + }, + remoteToLocalNames: map[string]string{ + "bar/t1.py": "bar/t1.py", + "bar/t2.py": "bar/t2.py", + "bar/notebook": "bar/notebook.py", + }, + }, + }, + }) +} + +func TestStatePullNewerDeploymentStateVersion(t *testing.T) { + s := &statePull{func(b *bundle.Bundle) (filer.Filer, error) { + f := mockfiler.NewMockFiler(t) + + deploymentStateData, err := json.Marshal(DeploymentState{ + Version: DeploymentStateVersion + 1, + Seq: 1, + CliVersion: "1.2.3", + Files: []File{ + { + LocalPath: "bar/t1.py", + }, + { + LocalPath: "bar/t2.py", + }, + }, + }) + require.NoError(t, err) + + f.EXPECT().Read(mock.Anything, DeploymentStateFileName).Return(io.NopCloser(bytes.NewReader(deploymentStateData)), nil) + + return f, nil + }} + + b := &bundle.Bundle{ + Config: config.Root{ + Path: t.TempDir(), + Bundle: config.Bundle{ + Target: "default", + }, + Workspace: config.Workspace{ + StatePath: "/state", + }, + }, + } + ctx := context.Background() + + err := bundle.Apply(ctx, b, s) + require.Error(t, err) + require.Contains(t, err.Error(), "remote deployment state is incompatible with the current version of the CLI, please upgrade to at least 1.2.3") +} diff --git a/bundle/deploy/state_push.go b/bundle/deploy/state_push.go new file mode 100644 index 00000000..8818d0a7 --- /dev/null +++ b/bundle/deploy/state_push.go @@ -0,0 +1,49 @@ +package deploy + +import ( + "context" + "os" + + "github.com/databricks/cli/bundle" + "github.com/databricks/cli/libs/filer" + "github.com/databricks/cli/libs/log" +) + +type statePush struct { + filerFactory FilerFactory +} + +func (s *statePush) Name() string { + return "deploy:state-push" +} + +func (s *statePush) Apply(ctx context.Context, b *bundle.Bundle) error { + f, err := s.filerFactory(b) + if err != nil { + return err + } + + statePath, err := getPathToStateFile(ctx, b) + if err != nil { + return err + } + + local, err := os.Open(statePath) + if err != nil { + return err + } + defer local.Close() + + log.Infof(ctx, "Writing local deployment state file to remote state directory") + err = f.Write(ctx, DeploymentStateFileName, local, filer.CreateParentDirectories, filer.OverwriteIfExists) + if err != nil { + return err + } + + return nil +} + +// StatePush returns a mutator that pushes the deployment state file to Databricks workspace. +func StatePush() bundle.Mutator { + return &statePush{StateFiler} +} diff --git a/bundle/deploy/state_push_test.go b/bundle/deploy/state_push_test.go new file mode 100644 index 00000000..37b865ec --- /dev/null +++ b/bundle/deploy/state_push_test.go @@ -0,0 +1,82 @@ +package deploy + +import ( + "context" + "encoding/json" + "io" + "os" + "testing" + + "github.com/databricks/cli/bundle" + "github.com/databricks/cli/bundle/config" + mockfiler "github.com/databricks/cli/internal/mocks/libs/filer" + "github.com/databricks/cli/libs/filer" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestStatePush(t *testing.T) { + s := &statePush{func(b *bundle.Bundle) (filer.Filer, error) { + f := mockfiler.NewMockFiler(t) + + f.EXPECT().Write(mock.Anything, DeploymentStateFileName, mock.MatchedBy(func(r *os.File) bool { + bytes, err := io.ReadAll(r) + if err != nil { + return false + } + + var state DeploymentState + err = json.Unmarshal(bytes, &state) + if err != nil { + return false + } + + if state.Seq != 1 { + return false + } + + if len(state.Files) != 1 { + return false + } + + return true + }), filer.CreateParentDirectories, filer.OverwriteIfExists).Return(nil) + return f, nil + }} + + b := &bundle.Bundle{ + Config: config.Root{ + Path: t.TempDir(), + Bundle: config.Bundle{ + Target: "default", + }, + Workspace: config.Workspace{ + StatePath: "/state", + }, + }, + } + + ctx := context.Background() + + statePath, err := getPathToStateFile(ctx, b) + require.NoError(t, err) + + state := DeploymentState{ + Version: DeploymentStateVersion, + Seq: 1, + Files: []File{ + { + LocalPath: "bar/t1.py", + }, + }, + } + + data, err := json.Marshal(state) + require.NoError(t, err) + + err = os.WriteFile(statePath, data, 0644) + require.NoError(t, err) + + err = bundle.Apply(ctx, b, s) + require.NoError(t, err) +} diff --git a/bundle/deploy/state_test.go b/bundle/deploy/state_test.go new file mode 100644 index 00000000..15bdc96b --- /dev/null +++ b/bundle/deploy/state_test.go @@ -0,0 +1,79 @@ +package deploy + +import ( + "bytes" + "encoding/json" + "path/filepath" + "testing" + + "github.com/databricks/cli/internal/testutil" + "github.com/databricks/cli/libs/fileset" + "github.com/stretchr/testify/require" +) + +func TestFromSlice(t *testing.T) { + tmpDir := t.TempDir() + fileset := fileset.New(tmpDir) + testutil.Touch(t, tmpDir, "test1.py") + testutil.Touch(t, tmpDir, "test2.py") + testutil.Touch(t, tmpDir, "test3.py") + + files, err := fileset.All() + require.NoError(t, err) + + f, err := FromSlice(files) + require.NoError(t, err) + require.Len(t, f, 3) + + for _, file := range f { + require.Contains(t, []string{"test1.py", "test2.py", "test3.py"}, file.LocalPath) + } +} + +func TestToSlice(t *testing.T) { + tmpDir := t.TempDir() + fileset := fileset.New(tmpDir) + testutil.Touch(t, tmpDir, "test1.py") + testutil.Touch(t, tmpDir, "test2.py") + testutil.Touch(t, tmpDir, "test3.py") + + files, err := fileset.All() + require.NoError(t, err) + + f, err := FromSlice(files) + require.NoError(t, err) + require.Len(t, f, 3) + + s := f.ToSlice(tmpDir) + require.Len(t, s, 3) + + for _, file := range s { + require.Contains(t, []string{"test1.py", "test2.py", "test3.py"}, file.Name()) + require.Contains(t, []string{ + filepath.Join(tmpDir, "test1.py"), + filepath.Join(tmpDir, "test2.py"), + filepath.Join(tmpDir, "test3.py"), + }, file.Absolute) + require.False(t, file.IsDir()) + require.NotZero(t, file.Type()) + info, err := file.Info() + require.NoError(t, err) + require.NotNil(t, info) + require.Equal(t, file.Name(), info.Name()) + } +} + +func TestIsLocalStateStale(t *testing.T) { + oldState, err := json.Marshal(DeploymentState{ + Seq: 1, + }) + require.NoError(t, err) + + newState, err := json.Marshal(DeploymentState{ + Seq: 2, + }) + require.NoError(t, err) + + require.True(t, isLocalStateStale(bytes.NewReader(oldState), bytes.NewReader(newState))) + require.False(t, isLocalStateStale(bytes.NewReader(newState), bytes.NewReader(oldState))) +} diff --git a/bundle/deploy/state_update.go b/bundle/deploy/state_update.go new file mode 100644 index 00000000..0ae61a6e --- /dev/null +++ b/bundle/deploy/state_update.go @@ -0,0 +1,108 @@ +package deploy + +import ( + "bytes" + "context" + "encoding/json" + "io" + "os" + "time" + + "github.com/databricks/cli/bundle" + "github.com/databricks/cli/bundle/deploy/files" + "github.com/databricks/cli/internal/build" + "github.com/databricks/cli/libs/log" +) + +type stateUpdate struct { +} + +func (s *stateUpdate) Name() string { + return "deploy:state-update" +} + +func (s *stateUpdate) Apply(ctx context.Context, b *bundle.Bundle) error { + state, err := load(ctx, b) + if err != nil { + return err + } + + // Increment the state sequence. + state.Seq = state.Seq + 1 + + // Update timestamp. + state.Timestamp = time.Now().UTC() + + // Update the CLI version and deployment state version. + state.CliVersion = build.GetInfo().Version + state.Version = DeploymentStateVersion + + // Get the current file list. + sync, err := files.GetSync(ctx, b) + if err != nil { + return err + } + + files, err := sync.GetFileList(ctx) + if err != nil { + return err + } + + // Update the state with the current file list. + fl, err := FromSlice(files) + if err != nil { + return err + } + state.Files = fl + + statePath, err := getPathToStateFile(ctx, b) + if err != nil { + return err + } + // Write the state back to the file. + f, err := os.OpenFile(statePath, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0600) + if err != nil { + log.Infof(ctx, "Unable to open deployment state file: %s", err) + return err + } + defer f.Close() + + data, err := json.Marshal(state) + if err != nil { + return err + } + + _, err = io.Copy(f, bytes.NewReader(data)) + if err != nil { + return err + } + + return nil +} + +func StateUpdate() bundle.Mutator { + return &stateUpdate{} +} + +func load(ctx context.Context, b *bundle.Bundle) (*DeploymentState, error) { + // If the file does not exist, return a new DeploymentState. + statePath, err := getPathToStateFile(ctx, b) + if err != nil { + return nil, err + } + + log.Infof(ctx, "Loading deployment state from %s", statePath) + f, err := os.Open(statePath) + if err != nil { + if os.IsNotExist(err) { + log.Infof(ctx, "No deployment state file found") + return &DeploymentState{ + Version: DeploymentStateVersion, + CliVersion: build.GetInfo().Version, + }, nil + } + return nil, err + } + defer f.Close() + return loadState(f) +} diff --git a/bundle/deploy/state_update_test.go b/bundle/deploy/state_update_test.go new file mode 100644 index 00000000..5e16dd00 --- /dev/null +++ b/bundle/deploy/state_update_test.go @@ -0,0 +1,149 @@ +package deploy + +import ( + "context" + "encoding/json" + "os" + "testing" + + "github.com/databricks/cli/bundle" + "github.com/databricks/cli/bundle/config" + "github.com/databricks/cli/internal/build" + "github.com/databricks/cli/internal/testutil" + databrickscfg "github.com/databricks/databricks-sdk-go/config" + "github.com/databricks/databricks-sdk-go/experimental/mocks" + "github.com/databricks/databricks-sdk-go/service/iam" + "github.com/databricks/databricks-sdk-go/service/workspace" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestStateUpdate(t *testing.T) { + s := &stateUpdate{} + + b := &bundle.Bundle{ + Config: config.Root{ + Path: t.TempDir(), + Bundle: config.Bundle{ + Target: "default", + }, + Workspace: config.Workspace{ + StatePath: "/state", + FilePath: "/files", + CurrentUser: &config.User{ + User: &iam.User{ + UserName: "test-user", + }, + }, + }, + }, + } + + testutil.Touch(t, b.Config.Path, "test1.py") + testutil.Touch(t, b.Config.Path, "test2.py") + + m := mocks.NewMockWorkspaceClient(t) + m.WorkspaceClient.Config = &databrickscfg.Config{ + Host: "https://test.com", + } + b.SetWorkpaceClient(m.WorkspaceClient) + + wsApi := m.GetMockWorkspaceAPI() + wsApi.EXPECT().GetStatusByPath(mock.Anything, "/files").Return(&workspace.ObjectInfo{ + ObjectType: "DIRECTORY", + }, nil) + + ctx := context.Background() + + err := bundle.Apply(ctx, b, s) + require.NoError(t, err) + + // Check that the state file was updated. + state, err := load(ctx, b) + require.NoError(t, err) + + require.Equal(t, int64(1), state.Seq) + require.Len(t, state.Files, 3) + require.Equal(t, build.GetInfo().Version, state.CliVersion) + + err = bundle.Apply(ctx, b, s) + require.NoError(t, err) + + // Check that the state file was updated again. + state, err = load(ctx, b) + require.NoError(t, err) + + require.Equal(t, int64(2), state.Seq) + require.Len(t, state.Files, 3) + require.Equal(t, build.GetInfo().Version, state.CliVersion) +} + +func TestStateUpdateWithExistingState(t *testing.T) { + s := &stateUpdate{} + + b := &bundle.Bundle{ + Config: config.Root{ + Path: t.TempDir(), + Bundle: config.Bundle{ + Target: "default", + }, + Workspace: config.Workspace{ + StatePath: "/state", + FilePath: "/files", + CurrentUser: &config.User{ + User: &iam.User{ + UserName: "test-user", + }, + }, + }, + }, + } + + testutil.Touch(t, b.Config.Path, "test1.py") + testutil.Touch(t, b.Config.Path, "test2.py") + + m := mocks.NewMockWorkspaceClient(t) + m.WorkspaceClient.Config = &databrickscfg.Config{ + Host: "https://test.com", + } + b.SetWorkpaceClient(m.WorkspaceClient) + + wsApi := m.GetMockWorkspaceAPI() + wsApi.EXPECT().GetStatusByPath(mock.Anything, "/files").Return(&workspace.ObjectInfo{ + ObjectType: "DIRECTORY", + }, nil) + + ctx := context.Background() + + // Create an existing state file. + statePath, err := getPathToStateFile(ctx, b) + require.NoError(t, err) + + state := &DeploymentState{ + Version: DeploymentStateVersion, + Seq: 10, + CliVersion: build.GetInfo().Version, + Files: []File{ + { + LocalPath: "bar/t1.py", + }, + }, + } + + data, err := json.Marshal(state) + require.NoError(t, err) + + err = os.WriteFile(statePath, data, 0644) + require.NoError(t, err) + + err = bundle.Apply(ctx, b, s) + require.NoError(t, err) + + // Check that the state file was updated. + state, err = load(ctx, b) + require.NoError(t, err) + + require.Equal(t, int64(11), state.Seq) + require.Len(t, state.Files, 3) + require.Equal(t, build.GetInfo().Version, state.CliVersion) +} diff --git a/bundle/deploy/terraform/filer.go b/bundle/deploy/terraform/filer.go deleted file mode 100644 index b1fa5a1b..00000000 --- a/bundle/deploy/terraform/filer.go +++ /dev/null @@ -1,14 +0,0 @@ -package terraform - -import ( - "github.com/databricks/cli/bundle" - "github.com/databricks/cli/libs/filer" -) - -// filerFunc is a function that returns a filer.Filer. -type filerFunc func(b *bundle.Bundle) (filer.Filer, error) - -// stateFiler returns a filer.Filer that can be used to read/write state files. -func stateFiler(b *bundle.Bundle) (filer.Filer, error) { - return filer.NewWorkspaceFilesClient(b.WorkspaceClient(), b.Config.Workspace.StatePath) -} diff --git a/bundle/deploy/terraform/state_pull.go b/bundle/deploy/terraform/state_pull.go index 14e8ecf1..045222ae 100644 --- a/bundle/deploy/terraform/state_pull.go +++ b/bundle/deploy/terraform/state_pull.go @@ -10,12 +10,13 @@ import ( "path/filepath" "github.com/databricks/cli/bundle" + "github.com/databricks/cli/bundle/deploy" "github.com/databricks/cli/libs/filer" "github.com/databricks/cli/libs/log" ) type statePull struct { - filerFunc + filerFactory deploy.FilerFactory } func (l *statePull) Name() string { @@ -45,7 +46,7 @@ func (l *statePull) remoteState(ctx context.Context, f filer.Filer) (*bytes.Buff } func (l *statePull) Apply(ctx context.Context, b *bundle.Bundle) error { - f, err := l.filerFunc(b) + f, err := l.filerFactory(b) if err != nil { return err } @@ -94,5 +95,5 @@ func (l *statePull) Apply(ctx context.Context, b *bundle.Bundle) error { } func StatePull() bundle.Mutator { - return &statePull{stateFiler} + return &statePull{deploy.StateFiler} } diff --git a/bundle/deploy/terraform/state_push.go b/bundle/deploy/terraform/state_push.go index a5140329..f701db87 100644 --- a/bundle/deploy/terraform/state_push.go +++ b/bundle/deploy/terraform/state_push.go @@ -6,13 +6,14 @@ import ( "path/filepath" "github.com/databricks/cli/bundle" + "github.com/databricks/cli/bundle/deploy" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/filer" "github.com/databricks/cli/libs/log" ) type statePush struct { - filerFunc + filerFactory deploy.FilerFactory } func (l *statePush) Name() string { @@ -20,7 +21,7 @@ func (l *statePush) Name() string { } func (l *statePush) Apply(ctx context.Context, b *bundle.Bundle) error { - f, err := l.filerFunc(b) + f, err := l.filerFactory(b) if err != nil { return err } @@ -49,5 +50,5 @@ func (l *statePush) Apply(ctx context.Context, b *bundle.Bundle) error { } func StatePush() bundle.Mutator { - return &statePush{stateFiler} + return &statePush{deploy.StateFiler} } diff --git a/bundle/deploy/terraform/state_test.go b/bundle/deploy/terraform/state_test.go index ee15b953..ff325062 100644 --- a/bundle/deploy/terraform/state_test.go +++ b/bundle/deploy/terraform/state_test.go @@ -8,12 +8,13 @@ import ( "testing" "github.com/databricks/cli/bundle" + "github.com/databricks/cli/bundle/deploy" "github.com/databricks/cli/libs/filer" "github.com/stretchr/testify/require" ) -// identityFiler returns a filerFunc that returns the specified filer. -func identityFiler(f filer.Filer) filerFunc { +// identityFiler returns a FilerFactory that returns the specified filer. +func identityFiler(f filer.Filer) deploy.FilerFactory { return func(_ *bundle.Bundle) (filer.Filer, error) { return f, nil } diff --git a/bundle/phases/deploy.go b/bundle/phases/deploy.go index 5c657550..f266a98f 100644 --- a/bundle/phases/deploy.go +++ b/bundle/phases/deploy.go @@ -24,6 +24,7 @@ func Deploy() bundle.Mutator { bundle.Defer( bundle.Seq( terraform.StatePull(), + deploy.StatePull(), deploy.CheckRunningResource(), mutator.ValidateGitDetails(), libraries.MatchWithArtifacts(), @@ -31,6 +32,7 @@ func Deploy() bundle.Mutator { artifacts.UploadAll(), python.TransformWheelTask(), files.Upload(), + deploy.StateUpdate(), permissions.ApplyWorkspaceRootPermissions(), terraform.Interpolate(), terraform.Write(), @@ -38,6 +40,7 @@ func Deploy() bundle.Mutator { terraform.Apply(), bundle.Seq( terraform.StatePush(), + deploy.StatePush(), terraform.Load(), metadata.Compute(), metadata.Upload(), diff --git a/cmd/bundle/sync.go b/cmd/bundle/sync.go index 20ec2fcd..0b7ab447 100644 --- a/cmd/bundle/sync.go +++ b/cmd/bundle/sync.go @@ -5,6 +5,7 @@ import ( "time" "github.com/databricks/cli/bundle" + "github.com/databricks/cli/bundle/deploy/files" "github.com/databricks/cli/bundle/phases" "github.com/databricks/cli/cmd/bundle/utils" "github.com/databricks/cli/cmd/root" @@ -20,28 +21,14 @@ type syncFlags struct { } func (f *syncFlags) syncOptionsFromBundle(cmd *cobra.Command, b *bundle.Bundle) (*sync.SyncOptions, error) { - cacheDir, err := b.CacheDir(cmd.Context()) + opts, err := files.GetSyncOptions(cmd.Context(), b) if err != nil { - return nil, fmt.Errorf("cannot get bundle cache directory: %w", err) + return nil, fmt.Errorf("cannot get sync options: %w", err) } - includes, err := b.GetSyncIncludePatterns(cmd.Context()) - if err != nil { - return nil, fmt.Errorf("cannot get list of sync includes: %w", err) - } - - opts := sync.SyncOptions{ - LocalPath: b.Config.Path, - RemotePath: b.Config.Workspace.FilePath, - Include: includes, - Exclude: b.Config.Sync.Exclude, - Full: f.full, - PollInterval: f.interval, - - SnapshotBasePath: cacheDir, - WorkspaceClient: b.WorkspaceClient(), - } - return &opts, nil + opts.Full = f.full + opts.PollInterval = f.interval + return opts, nil } func newSyncCommand() *cobra.Command { diff --git a/cmd/sync/sync.go b/cmd/sync/sync.go index f08d3d61..6899d6fe 100644 --- a/cmd/sync/sync.go +++ b/cmd/sync/sync.go @@ -10,6 +10,7 @@ import ( "time" "github.com/databricks/cli/bundle" + "github.com/databricks/cli/bundle/deploy/files" "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/libs/flags" "github.com/databricks/cli/libs/sync" @@ -29,28 +30,14 @@ func (f *syncFlags) syncOptionsFromBundle(cmd *cobra.Command, args []string, b * return nil, fmt.Errorf("SRC and DST are not configurable in the context of a bundle") } - cacheDir, err := b.CacheDir(cmd.Context()) + opts, err := files.GetSyncOptions(cmd.Context(), b) if err != nil { - return nil, fmt.Errorf("cannot get bundle cache directory: %w", err) + return nil, fmt.Errorf("cannot get sync options: %w", err) } - includes, err := b.GetSyncIncludePatterns(cmd.Context()) - if err != nil { - return nil, fmt.Errorf("cannot get list of sync includes: %w", err) - } - - opts := sync.SyncOptions{ - LocalPath: b.Config.Path, - RemotePath: b.Config.Workspace.FilePath, - Include: includes, - Exclude: b.Config.Sync.Exclude, - Full: f.full, - PollInterval: f.interval, - - SnapshotBasePath: cacheDir, - WorkspaceClient: b.WorkspaceClient(), - } - return &opts, nil + opts.Full = f.full + opts.PollInterval = f.interval + return opts, nil } func (f *syncFlags) syncOptionsFromArgs(cmd *cobra.Command, args []string) (*sync.SyncOptions, error) { diff --git a/internal/bundle/deployment_state_test.go b/internal/bundle/deployment_state_test.go new file mode 100644 index 00000000..25f36d4a --- /dev/null +++ b/internal/bundle/deployment_state_test.go @@ -0,0 +1,102 @@ +package bundle + +import ( + "os" + "path" + "path/filepath" + "testing" + + "github.com/databricks/cli/bundle/deploy" + "github.com/databricks/cli/internal" + "github.com/databricks/cli/internal/acc" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +func TestAccFilesAreSyncedCorrectlyWhenNoSnapshot(t *testing.T) { + env := internal.GetEnvOrSkipTest(t, "CLOUD_ENV") + t.Log(env) + + ctx, wt := acc.WorkspaceTest(t) + w := wt.W + + nodeTypeId := internal.GetNodeTypeId(env) + uniqueId := uuid.New().String() + bundleRoot, err := initTestTemplate(t, ctx, "basic", map[string]any{ + "unique_id": uniqueId, + "spark_version": "13.3.x-scala2.12", + "node_type_id": nodeTypeId, + }) + require.NoError(t, err) + + t.Setenv("BUNDLE_ROOT", bundleRoot) + + // Add some test file to the bundle + err = os.WriteFile(filepath.Join(bundleRoot, "test.py"), []byte("print('Hello, World!')"), 0644) + require.NoError(t, err) + + err = os.WriteFile(filepath.Join(bundleRoot, "test_to_modify.py"), []byte("print('Hello, World!')"), 0644) + require.NoError(t, err) + + // Add notebook to the bundle + err = os.WriteFile(filepath.Join(bundleRoot, "notebook.py"), []byte("# Databricks notebook source\nHello, World!"), 0644) + require.NoError(t, err) + + err = deployBundle(t, ctx, bundleRoot) + require.NoError(t, err) + + t.Cleanup(func() { + destroyBundle(t, ctx, bundleRoot) + }) + + remoteRoot := getBundleRemoteRootPath(w, t, uniqueId) + + // Check that test file is in workspace + _, err = w.Workspace.GetStatusByPath(ctx, path.Join(remoteRoot, "files", "test.py")) + require.NoError(t, err) + + _, err = w.Workspace.GetStatusByPath(ctx, path.Join(remoteRoot, "files", "test_to_modify.py")) + require.NoError(t, err) + + // Check that notebook is in workspace + _, err = w.Workspace.GetStatusByPath(ctx, path.Join(remoteRoot, "files", "notebook")) + require.NoError(t, err) + + // Check that deployment.json is synced correctly + _, err = w.Workspace.GetStatusByPath(ctx, path.Join(remoteRoot, "state", deploy.DeploymentStateFileName)) + require.NoError(t, err) + + // Remove .databricks directory to simulate a fresh deployment like in CI/CD environment + err = os.RemoveAll(filepath.Join(bundleRoot, ".databricks")) + require.NoError(t, err) + + // Remove the file from the bundle + err = os.Remove(filepath.Join(bundleRoot, "test.py")) + require.NoError(t, err) + + // Remove the notebook from the bundle and deploy again + err = os.Remove(filepath.Join(bundleRoot, "notebook.py")) + require.NoError(t, err) + + // Modify the content of another file + err = os.WriteFile(filepath.Join(bundleRoot, "test_to_modify.py"), []byte("print('Modified!')"), 0644) + require.NoError(t, err) + + err = deployBundle(t, ctx, bundleRoot) + require.NoError(t, err) + + // Check that removed file is not in workspace anymore + _, err = w.Workspace.GetStatusByPath(ctx, path.Join(remoteRoot, "files", "test.py")) + require.ErrorContains(t, err, "files/test.py") + require.ErrorContains(t, err, "doesn't exist") + + // Check that removed notebook is not in workspace anymore + _, err = w.Workspace.GetStatusByPath(ctx, path.Join(remoteRoot, "files", "notebook")) + require.ErrorContains(t, err, "files/notebook") + require.ErrorContains(t, err, "doesn't exist") + + // Check the content of modified file + content, err := w.Workspace.ReadFile(ctx, path.Join(remoteRoot, "files", "test_to_modify.py")) + require.NoError(t, err) + require.Equal(t, "print('Modified!')", string(content)) +} diff --git a/internal/bundle/helpers.go b/internal/bundle/helpers.go index a8fbd230..10e315bd 100644 --- a/internal/bundle/helpers.go +++ b/internal/bundle/helpers.go @@ -3,6 +3,7 @@ package bundle import ( "context" "encoding/json" + "fmt" "os" "path/filepath" "strings" @@ -13,6 +14,8 @@ import ( "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/flags" "github.com/databricks/cli/libs/template" + "github.com/databricks/databricks-sdk-go" + "github.com/stretchr/testify/require" ) func initTestTemplate(t *testing.T, ctx context.Context, templateName string, config map[string]any) (string, error) { @@ -78,3 +81,11 @@ func destroyBundle(t *testing.T, ctx context.Context, path string) error { _, _, err := c.Run() return err } + +func getBundleRemoteRootPath(w *databricks.WorkspaceClient, t *testing.T, uniqueId string) string { + // Compute root path for the bundle deployment + me, err := w.CurrentUser.Me(context.Background()) + require.NoError(t, err) + root := fmt.Sprintf("/Users/%s/.bundle/%s", me.UserName, uniqueId) + return root +} diff --git a/internal/bundle/job_metadata_test.go b/internal/bundle/job_metadata_test.go index 0d8a431e..cb3ad081 100644 --- a/internal/bundle/job_metadata_test.go +++ b/internal/bundle/job_metadata_test.go @@ -3,7 +3,6 @@ package bundle import ( "context" "encoding/json" - "fmt" "io" "path" "strconv" @@ -56,9 +55,7 @@ func TestAccJobsMetadataFile(t *testing.T) { assert.Equal(t, job2.Settings.Name, jobName) // Compute root path for the bundle deployment - me, err := w.CurrentUser.Me(context.Background()) - require.NoError(t, err) - root := fmt.Sprintf("/Users/%s/.bundle/%s", me.UserName, uniqueId) + root := getBundleRemoteRootPath(w, t, uniqueId) f, err := filer.NewWorkspaceFilesClient(w, root) require.NoError(t, err) diff --git a/internal/testutil/helpers.go b/internal/testutil/helpers.go new file mode 100644 index 00000000..853cc16c --- /dev/null +++ b/internal/testutil/helpers.go @@ -0,0 +1,26 @@ +package testutil + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TouchNotebook(t *testing.T, path, file string) { + os.MkdirAll(path, 0755) + f, err := os.Create(filepath.Join(path, file)) + require.NoError(t, err) + + err = os.WriteFile(filepath.Join(path, file), []byte("# Databricks notebook source"), 0644) + require.NoError(t, err) + f.Close() +} + +func Touch(t *testing.T, path, file string) { + os.MkdirAll(path, 0755) + f, err := os.Create(filepath.Join(path, file)) + require.NoError(t, err) + f.Close() +} diff --git a/libs/fileset/file.go b/libs/fileset/file.go index 6594de4e..17cae795 100644 --- a/libs/fileset/file.go +++ b/libs/fileset/file.go @@ -3,11 +3,49 @@ package fileset import ( "io/fs" "time" + + "github.com/databricks/cli/libs/notebook" +) + +type fileType int + +const ( + Unknown fileType = iota + Notebook // Databricks notebook file + Source // Any other file type ) type File struct { fs.DirEntry Absolute, Relative string + fileType fileType +} + +func NewNotebookFile(entry fs.DirEntry, absolute string, relative string) File { + return File{ + DirEntry: entry, + Absolute: absolute, + Relative: relative, + fileType: Notebook, + } +} + +func NewFile(entry fs.DirEntry, absolute string, relative string) File { + return File{ + DirEntry: entry, + Absolute: absolute, + Relative: relative, + fileType: Unknown, + } +} + +func NewSourceFile(entry fs.DirEntry, absolute string, relative string) File { + return File{ + DirEntry: entry, + Absolute: absolute, + Relative: relative, + fileType: Source, + } } func (f File) Modified() (ts time.Time) { @@ -18,3 +56,21 @@ func (f File) Modified() (ts time.Time) { } return info.ModTime() } + +func (f *File) IsNotebook() (bool, error) { + if f.fileType != Unknown { + return f.fileType == Notebook, nil + } + + // Otherwise, detect the notebook type. + isNotebook, _, err := notebook.Detect(f.Absolute) + if err != nil { + return false, err + } + if isNotebook { + f.fileType = Notebook + } else { + f.fileType = Source + } + return isNotebook, nil +} diff --git a/libs/fileset/file_test.go b/libs/fileset/file_test.go new file mode 100644 index 00000000..4adcb1c5 --- /dev/null +++ b/libs/fileset/file_test.go @@ -0,0 +1,39 @@ +package fileset + +import ( + "path/filepath" + "testing" + + "github.com/databricks/cli/internal/testutil" + "github.com/stretchr/testify/require" +) + +func TestNotebookFileIsNotebook(t *testing.T) { + f := NewNotebookFile(nil, "", "") + isNotebook, err := f.IsNotebook() + require.NoError(t, err) + require.True(t, isNotebook) +} + +func TestSourceFileIsNotNotebook(t *testing.T) { + f := NewSourceFile(nil, "", "") + isNotebook, err := f.IsNotebook() + require.NoError(t, err) + require.False(t, isNotebook) +} + +func TestUnknownFileDetectsNotebook(t *testing.T) { + tmpDir := t.TempDir() + testutil.Touch(t, tmpDir, "test.py") + testutil.TouchNotebook(t, tmpDir, "notebook.py") + + f := NewFile(nil, filepath.Join(tmpDir, "test.py"), "test.py") + isNotebook, err := f.IsNotebook() + require.NoError(t, err) + require.False(t, isNotebook) + + f = NewFile(nil, filepath.Join(tmpDir, "notebook.py"), "notebook.py") + isNotebook, err = f.IsNotebook() + require.NoError(t, err) + require.True(t, isNotebook) +} diff --git a/libs/fileset/fileset.go b/libs/fileset/fileset.go index 81b85525..52463dff 100644 --- a/libs/fileset/fileset.go +++ b/libs/fileset/fileset.go @@ -84,7 +84,7 @@ func (w *FileSet) recursiveListFiles() (fileList []File, err error) { return nil } - fileList = append(fileList, File{d, path, relPath}) + fileList = append(fileList, NewFile(d, path, relPath)) return nil }) return diff --git a/libs/sync/snapshot.go b/libs/sync/snapshot.go index f9956962..06b4d13b 100644 --- a/libs/sync/snapshot.go +++ b/libs/sync/snapshot.go @@ -53,6 +53,30 @@ type Snapshot struct { const syncSnapshotDirName = "sync-snapshots" +func NewSnapshot(localFiles []fileset.File, opts *SyncOptions) (*Snapshot, error) { + snapshotPath, err := SnapshotPath(opts) + if err != nil { + return nil, err + } + + snapshotState, err := NewSnapshotState(localFiles) + if err != nil { + return nil, err + } + + // Reset last modified times to make sure all files are synced + snapshotState.ResetLastModifiedTimes() + + return &Snapshot{ + SnapshotPath: snapshotPath, + New: true, + Version: LatestSnapshotVersion, + Host: opts.Host, + RemotePath: opts.RemotePath, + SnapshotState: snapshotState, + }, nil +} + func GetFileName(host, remotePath string) string { hash := md5.Sum([]byte(host + remotePath)) hashString := hex.EncodeToString(hash[:]) diff --git a/libs/sync/snapshot_state.go b/libs/sync/snapshot_state.go index 57506352..10cd34e6 100644 --- a/libs/sync/snapshot_state.go +++ b/libs/sync/snapshot_state.go @@ -7,7 +7,6 @@ import ( "time" "github.com/databricks/cli/libs/fileset" - "github.com/databricks/cli/libs/notebook" ) // SnapshotState keeps track of files on the local filesystem and their corresponding @@ -46,10 +45,12 @@ func NewSnapshotState(localFiles []fileset.File) (*SnapshotState, error) { } // Compute the new state. - for _, f := range localFiles { + for k := range localFiles { + f := &localFiles[k] // Compute the remote name the file will have in WSFS remoteName := filepath.ToSlash(f.Relative) - isNotebook, _, err := notebook.Detect(f.Absolute) + isNotebook, err := f.IsNotebook() + if err != nil { // Ignore this file if we're unable to determine the notebook type. // Trying to upload such a file to the workspace would fail anyway. @@ -72,6 +73,12 @@ func NewSnapshotState(localFiles []fileset.File) (*SnapshotState, error) { return fs, nil } +func (fs *SnapshotState) ResetLastModifiedTimes() { + for k := range fs.LastModifiedTimes { + fs.LastModifiedTimes[k] = time.Unix(0, 0) + } +} + // Consistency checks for the sync files state representation. These are invariants // that downstream code for computing changes to apply to WSFS depends on. // diff --git a/libs/sync/sync.go b/libs/sync/sync.go index beb3f6a3..78faa0c8 100644 --- a/libs/sync/sync.go +++ b/libs/sync/sync.go @@ -151,7 +151,7 @@ func (s *Sync) notifyComplete(ctx context.Context, d diff) { } func (s *Sync) RunOnce(ctx context.Context) error { - files, err := getFileList(ctx, s) + files, err := s.GetFileList(ctx) if err != nil { return err } @@ -182,7 +182,7 @@ func (s *Sync) RunOnce(ctx context.Context) error { return nil } -func getFileList(ctx context.Context, s *Sync) ([]fileset.File, error) { +func (s *Sync) GetFileList(ctx context.Context) ([]fileset.File, error) { // tradeoff: doing portable monitoring only due to macOS max descriptor manual ulimit setting requirement // https://github.com/gorakhargosh/watchdog/blob/master/src/watchdog/observers/kqueue.py#L394-L418 all := set.NewSetF(func(f fileset.File) string { diff --git a/libs/sync/sync_test.go b/libs/sync/sync_test.go index 0f1ad61b..dc220dbf 100644 --- a/libs/sync/sync_test.go +++ b/libs/sync/sync_test.go @@ -93,7 +93,7 @@ func TestGetFileSet(t *testing.T) { excludeFileSet: excl, } - fileList, err := getFileList(ctx, s) + fileList, err := s.GetFileList(ctx) require.NoError(t, err) require.Equal(t, len(fileList), 9) @@ -111,7 +111,7 @@ func TestGetFileSet(t *testing.T) { excludeFileSet: excl, } - fileList, err = getFileList(ctx, s) + fileList, err = s.GetFileList(ctx) require.NoError(t, err) require.Equal(t, len(fileList), 1) @@ -129,7 +129,7 @@ func TestGetFileSet(t *testing.T) { excludeFileSet: excl, } - fileList, err = getFileList(ctx, s) + fileList, err = s.GetFileList(ctx) require.NoError(t, err) require.Equal(t, len(fileList), 10) } @@ -158,7 +158,7 @@ func TestRecursiveExclude(t *testing.T) { excludeFileSet: excl, } - fileList, err := getFileList(ctx, s) + fileList, err := s.GetFileList(ctx) require.NoError(t, err) require.Equal(t, len(fileList), 7) }