Do not truncate local state file when pulling remote changes (#382)

## Changes
When local state file exists it won't be override by remote state file

## Tests
Running `bricks bundle deploy` after state push failed does not override
local state file

Use cases verified:
1. Local state file is newer than remote
2. Local state file is older than remote
3. Local state file does not exist
4. Local state file corrupted
This commit is contained in:
Andrew Nester 2023-05-16 17:02:33 +02:00 committed by GitHub
parent 2786ec85aa
commit 33fb0b3c40
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 142 additions and 1 deletions

View File

@ -42,10 +42,20 @@ func (l *statePull) Apply(ctx context.Context, b *bundle.Bundle) ([]bundle.Mutat
}
// Expect the state file to live under dir.
local, err := os.OpenFile(filepath.Join(dir, TerraformStateFileName), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
local, err := os.OpenFile(filepath.Join(dir, TerraformStateFileName), os.O_CREATE|os.O_RDWR, 0600)
if err != nil {
return nil, err
}
defer local.Close()
if !IsLocalStateStale(local, remote) {
log.Infof(ctx, "Local state is the same or newer, ignoring remote state")
return nil, nil
}
// Truncating the file before writing
local.Truncate(0)
local.Seek(0, 0)
// Write file to disk.
log.Infof(ctx, "Writing remote state file to local cache directory")

View File

@ -0,0 +1,38 @@
package terraform
import (
"encoding/json"
"io"
)
type state struct {
Serial int `json:"serial"`
}
func IsLocalStateStale(local io.Reader, remote io.Reader) bool {
localState, err := loadState(local)
if err != nil {
return true
}
remoteState, err := loadState(remote)
if err != nil {
return false
}
return localState.Serial < remoteState.Serial
}
func loadState(input io.Reader) (*state, error) {
content, err := io.ReadAll(input)
if err != nil {
return nil, err
}
var s state
err = json.Unmarshal(content, &s)
if err != nil {
return nil, err
}
return &s, nil
}

View File

@ -0,0 +1,93 @@
package terraform
import (
"fmt"
"io"
"testing"
"testing/iotest"
"github.com/stretchr/testify/assert"
)
type mockedReader struct {
content string
}
func (r *mockedReader) Read(p []byte) (n int, err error) {
content := []byte(r.content)
n = copy(p, content)
return n, io.EOF
}
func TestLocalStateIsNewer(t *testing.T) {
local := &mockedReader{content: `
{
"serial": 5
}
`}
remote := &mockedReader{content: `
{
"serial": 4
}
`}
stale := IsLocalStateStale(local, remote)
assert.False(t, stale)
}
func TestLocalStateIsOlder(t *testing.T) {
local := &mockedReader{content: `
{
"serial": 5
}
`}
remote := &mockedReader{content: `
{
"serial": 6
}
`}
stale := IsLocalStateStale(local, remote)
assert.True(t, stale)
}
func TestLocalStateIsTheSame(t *testing.T) {
local := &mockedReader{content: `
{
"serial": 5
}
`}
remote := &mockedReader{content: `
{
"serial": 5
}
`}
stale := IsLocalStateStale(local, remote)
assert.False(t, stale)
}
func TestLocalStateMarkStaleWhenFailsToLoad(t *testing.T) {
local := iotest.ErrReader(fmt.Errorf("Random error"))
remote := &mockedReader{content: `
{
"serial": 5
}
`}
stale := IsLocalStateStale(local, remote)
assert.True(t, stale)
}
func TestLocalStateMarkNonStaleWhenRemoteFailsToLoad(t *testing.T) {
local := &mockedReader{content: `
{
"serial": 5
}
`}
remote := iotest.ErrReader(fmt.Errorf("Random error"))
stale := IsLocalStateStale(local, remote)
assert.False(t, stale)
}