diff --git a/bundle/deploy/terraform/state_pull.go b/bundle/deploy/terraform/state_pull.go index 93ae3248..fb7162a5 100644 --- a/bundle/deploy/terraform/state_pull.go +++ b/bundle/deploy/terraform/state_pull.go @@ -42,10 +42,20 @@ func (l *statePull) Apply(ctx context.Context, b *bundle.Bundle) ([]bundle.Mutat } // Expect the state file to live under dir. - local, err := os.OpenFile(filepath.Join(dir, TerraformStateFileName), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600) + local, err := os.OpenFile(filepath.Join(dir, TerraformStateFileName), os.O_CREATE|os.O_RDWR, 0600) if err != nil { return nil, err } + defer local.Close() + + if !IsLocalStateStale(local, remote) { + log.Infof(ctx, "Local state is the same or newer, ignoring remote state") + return nil, nil + } + + // Truncating the file before writing + local.Truncate(0) + local.Seek(0, 0) // Write file to disk. log.Infof(ctx, "Writing remote state file to local cache directory") diff --git a/bundle/deploy/terraform/util.go b/bundle/deploy/terraform/util.go new file mode 100644 index 00000000..a5978b39 --- /dev/null +++ b/bundle/deploy/terraform/util.go @@ -0,0 +1,38 @@ +package terraform + +import ( + "encoding/json" + "io" +) + +type state struct { + Serial int `json:"serial"` +} + +func IsLocalStateStale(local io.Reader, remote io.Reader) bool { + localState, err := loadState(local) + if err != nil { + return true + } + + remoteState, err := loadState(remote) + if err != nil { + return false + } + + return localState.Serial < remoteState.Serial +} + +func loadState(input io.Reader) (*state, error) { + content, err := io.ReadAll(input) + if err != nil { + return nil, err + } + var s state + err = json.Unmarshal(content, &s) + if err != nil { + return nil, err + } + + return &s, nil +} diff --git a/bundle/deploy/terraform/util_test.go b/bundle/deploy/terraform/util_test.go new file mode 100644 index 00000000..1ddfbab3 --- /dev/null +++ b/bundle/deploy/terraform/util_test.go @@ -0,0 +1,93 @@ +package terraform + +import ( + "fmt" + "io" + "testing" + "testing/iotest" + + "github.com/stretchr/testify/assert" +) + +type mockedReader struct { + content string +} + +func (r *mockedReader) Read(p []byte) (n int, err error) { + content := []byte(r.content) + n = copy(p, content) + return n, io.EOF +} + +func TestLocalStateIsNewer(t *testing.T) { + local := &mockedReader{content: ` +{ + "serial": 5 +} +`} + remote := &mockedReader{content: ` +{ + "serial": 4 +} +`} + + stale := IsLocalStateStale(local, remote) + + assert.False(t, stale) +} + +func TestLocalStateIsOlder(t *testing.T) { + local := &mockedReader{content: ` +{ + "serial": 5 +} +`} + remote := &mockedReader{content: ` +{ + "serial": 6 +} +`} + + stale := IsLocalStateStale(local, remote) + assert.True(t, stale) +} + +func TestLocalStateIsTheSame(t *testing.T) { + local := &mockedReader{content: ` +{ + "serial": 5 +} +`} + remote := &mockedReader{content: ` +{ + "serial": 5 +} +`} + + stale := IsLocalStateStale(local, remote) + assert.False(t, stale) +} + +func TestLocalStateMarkStaleWhenFailsToLoad(t *testing.T) { + local := iotest.ErrReader(fmt.Errorf("Random error")) + remote := &mockedReader{content: ` +{ + "serial": 5 +} +`} + + stale := IsLocalStateStale(local, remote) + assert.True(t, stale) +} + +func TestLocalStateMarkNonStaleWhenRemoteFailsToLoad(t *testing.T) { + local := &mockedReader{content: ` +{ + "serial": 5 +} +`} + remote := iotest.ErrReader(fmt.Errorf("Random error")) + + stale := IsLocalStateStale(local, remote) + assert.False(t, stale) +}