diff --git a/bundle/deployer/deployer.go b/bundle/deployer/deployer.go index 7a8bb01fc..1d780f721 100644 --- a/bundle/deployer/deployer.go +++ b/bundle/deployer/deployer.go @@ -2,10 +2,12 @@ package deployer import ( "context" + "errors" "fmt" + "io" + "io/fs" "os" "path/filepath" - "strings" "github.com/databricks/cli/libs/locker" "github.com/databricks/cli/libs/log" @@ -97,22 +99,24 @@ func (b *Deployer) tfStateLocalPath() string { return filepath.Join(b.DefaultTerraformRoot(), "terraform.tfstate") } -func (b *Deployer) LoadTerraformState(ctx context.Context) error { - bytes, err := b.locker.GetRawJsonFileContent(ctx, b.tfStateRemotePath()) - if err != nil { +func (d *Deployer) LoadTerraformState(ctx context.Context) error { + r, err := d.locker.Read(ctx, d.tfStateRemotePath()) + if errors.Is(err, fs.ErrNotExist) { // If remote tf state is absent, use local tf state - if strings.Contains(err.Error(), "File not found.") { - return nil - } else { - return err - } + return nil } - err = os.MkdirAll(b.DefaultTerraformRoot(), os.ModeDir) if err != nil { return err } - err = os.WriteFile(b.tfStateLocalPath(), bytes, os.ModePerm) - return err + err = os.MkdirAll(d.DefaultTerraformRoot(), os.ModeDir) + if err != nil { + return err + } + b, err := io.ReadAll(r) + if err != nil { + return err + } + return os.WriteFile(d.tfStateLocalPath(), b, os.ModePerm) } func (b *Deployer) SaveTerraformState(ctx context.Context) error { @@ -120,7 +124,7 @@ func (b *Deployer) SaveTerraformState(ctx context.Context) error { if err != nil { return err } - return b.locker.PutFile(ctx, b.tfStateRemotePath(), bytes) + return b.locker.Write(ctx, b.tfStateRemotePath(), bytes) } func (d *Deployer) Lock(ctx context.Context, isForced bool) error { diff --git a/internal/locker_test.go b/internal/locker_test.go index bc26bdaa4..f3e026d62 100644 --- a/internal/locker_test.go +++ b/internal/locker_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "io" "io/fs" "math/rand" "sync" @@ -114,18 +115,23 @@ func TestAccLock(t *testing.T) { if i == indexOfActiveLocker { continue } - err := lockers[i].PutFile(ctx, "foo.json", []byte(`'{"surname":"Khan", "name":"Shah Rukh"}`)) + err := lockers[i].Write(ctx, "foo.json", []byte(`'{"surname":"Khan", "name":"Shah Rukh"}`)) assert.ErrorContains(t, err, "failed to put file. deploy lock not held") } // active locker file write succeeds - err = lockers[indexOfActiveLocker].PutFile(ctx, "foo.json", []byte(`{"surname":"Khan", "name":"Shah Rukh"}`)) + err = lockers[indexOfActiveLocker].Write(ctx, "foo.json", []byte(`{"surname":"Khan", "name":"Shah Rukh"}`)) assert.NoError(t, err) - // active locker file read succeeds with expected results - bytes, err := lockers[indexOfActiveLocker].GetRawJsonFileContent(ctx, "foo.json") + // read active locker file + r, err := lockers[indexOfActiveLocker].Read(ctx, "foo.json") + require.NoError(t, err) + b, err := io.ReadAll(r) + require.NoError(t, err) + + // assert on active locker content var res map[string]string - json.Unmarshal(bytes, &res) + json.Unmarshal(b, &res) assert.NoError(t, err) assert.Equal(t, "Khan", res["surname"]) assert.Equal(t, "Shah Rukh", res["name"]) @@ -135,7 +141,7 @@ func TestAccLock(t *testing.T) { if i == indexOfActiveLocker { continue } - _, err = lockers[i].GetRawJsonFileContent(ctx, "foo.json") + _, err = lockers[i].Read(ctx, "foo.json") assert.ErrorContains(t, err, "failed to get file. deploy lock not held") } diff --git a/libs/locker/locker.go b/libs/locker/locker.go index 3b7725d9e..7c23f40e7 100644 --- a/libs/locker/locker.go +++ b/libs/locker/locker.go @@ -7,7 +7,7 @@ import ( "errors" "fmt" "io" - "strings" + "io/fs" "time" "github.com/databricks/cli/libs/filer" @@ -88,7 +88,7 @@ func (locker *Locker) GetActiveLockState(ctx context.Context) (*LockState, error // holder details if locker does not hold the lock func (locker *Locker) assertLockHeld(ctx context.Context) error { activeLockState, err := locker.GetActiveLockState(ctx) - if err != nil && strings.Contains(err.Error(), "File not found.") { + if errors.Is(err, fs.ErrNotExist) { return fmt.Errorf("no active lock on target dir: %s", err) } if err != nil { @@ -104,22 +104,18 @@ func (locker *Locker) assertLockHeld(ctx context.Context) error { } // idempotent function since overwrite is set to true -func (locker *Locker) PutFile(ctx context.Context, pathToFile string, content []byte) error { +func (locker *Locker) Write(ctx context.Context, pathToFile string, content []byte) error { if !locker.Active { return fmt.Errorf("failed to put file. deploy lock not held") } return locker.filer.Write(ctx, pathToFile, bytes.NewReader(content), filer.OverwriteIfExists, filer.CreateParentDirectories) } -func (locker *Locker) GetRawJsonFileContent(ctx context.Context, path string) ([]byte, error) { +func (locker *Locker) Read(ctx context.Context, path string) (io.ReadCloser, error) { if !locker.Active { return nil, fmt.Errorf("failed to get file. deploy lock not held") } - reader, err := locker.filer.Read(ctx, path) - if err != nil { - return nil, err - } - return io.ReadAll(reader) + return locker.filer.Read(ctx, path) } func (locker *Locker) Lock(ctx context.Context, isForced bool) error {